diff options
author | Luca Antiga <luca.antiga@orobix.com> | 2017-06-07 03:21:20 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-07-04 21:55:19 +0300 |
commit | e78e9aacd4894e65e3ce3524c517d2a39d12f338 (patch) | |
tree | ea6d50891010af1147619e9adae45cb29b23686f | |
parent | 6b32a6149564762a6028703f7655c8ef00dd4852 (diff) |
Have median reduce over all dims and return just the value when dim is not provided
-rw-r--r-- | lib/TH/generic/THTensorMath.c | 78 | ||||
-rw-r--r-- | lib/TH/generic/THTensorMath.h | 1 |
2 files changed, 79 insertions, 0 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index b955319..215c04c 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -586,6 +586,34 @@ real THTensor_(maxall)(THTensor *tensor) return theMax; } +static void THTensor_(quickselectnoidx)(real *arr, long k, long elements, long stride); + +real THTensor_(medianall)(THTensor *tensor) +{ + THArgCheck(tensor->nDimension > 0, 1, "tensor must have one dimension"); + THArgCheck(THTensor_(isContiguous)(tensor), 1, "input is not contiguous"); + + real theMedian; + ptrdiff_t numel; + long k; + THTensor *temp_; + real *temp__data; + + numel = THTensor_(nElement)(tensor); + k = (numel-1) >> 1; + + temp_ = THTensor_(newClone)(tensor); + temp__data = THTensor_(data)(temp_); + + THTensor_(quickselectnoidx)(temp__data, k, numel, 1); + + theMedian = temp__data[k]; + + THTensor_(free)(temp_); + + return theMedian; +} + accreal THTensor_(sumall)(THTensor *tensor) { accreal sum = 0; @@ -2044,6 +2072,9 @@ void THTensor_(reshape)(THTensor *r_, THTensor *t, THLongStorage *size) #define LONG_SWAP(AAA, BBB) swap = AAA; AAA = BBB; BBB = swap #define REAL_SWAP(AAA, BBB) rswap = AAA; AAA = BBB; BBB = rswap +#define ARR_SWAP(III, JJJ) \ + REAL_SWAP(ARR(III), ARR(JJJ)); + #define BOTH_SWAP(III, JJJ) \ REAL_SWAP(ARR(III), ARR(JJJ)); \ LONG_SWAP(IDX(III), IDX(JJJ)) @@ -2263,6 +2294,53 @@ void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimensio /* Implementation of the Quickselect algorithm, based on Nicolas Devillard's public domain implementation at http://ndevilla.free.fr/median/median/ +Adapted similarly to the above Quicksort algorithm. +This version does not produce indices along with values. */ +static void THTensor_(quickselectnoidx)(real *arr, long k, long elements, long stride) +{ + long P, L, R, i, j, swap; + real rswap, piv; + L = 0; + R = elements-1; + + do { + if (R <= L) /* One element only */ + return; + + if (R == L+1) { /* Two elements only */ + if (ARR(L) > ARR(R)) { + ARR_SWAP(L, R); + } + return; + } + + /* Use median of three for pivot choice */ + P=(L+R)>>1; + ARR_SWAP(P, L+1); + if (ARR(L+1) > ARR(R)) { ARR_SWAP(L+1, R); } + if (ARR(L) > ARR(R)) { ARR_SWAP(L, R); } + if (ARR(L+1) > ARR(L)) { ARR_SWAP(L+1, L); } + + i = L+1; + j = R; + piv = ARR(L); + do { + do i++; while(ARR(i) < piv); + do j--; while(ARR(j) > piv); + if (j < i) + break; + ARR_SWAP(i, j); + } while(1); + ARR_SWAP(L, j); + + /* Re-set active partition */ + if (j <= k) L=i; + if (j >= k) R=j-1; + } while(1); +} + +/* Implementation of the Quickselect algorithm, based on Nicolas Devillard's +public domain implementation at http://ndevilla.free.fr/median/median/ Adapted similarly to the above Quicksort algorithm. */ static void THTensor_(quickselect)(real *arr, long *idx, long k, long elements, long stride) { diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h index 6337533..17e54cc 100644 --- a/lib/TH/generic/THTensorMath.h +++ b/lib/TH/generic/THTensorMath.h @@ -25,6 +25,7 @@ TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src); TH_API real THTensor_(minall)(THTensor *t); TH_API real THTensor_(maxall)(THTensor *t); +TH_API real THTensor_(medianall)(THTensor *t); TH_API accreal THTensor_(sumall)(THTensor *t); TH_API accreal THTensor_(prodall)(THTensor *t); |