diff options
author | Jean-Marc Valin <jean-marc.valin@usherbrooke.ca> | 2008-09-22 06:33:14 +0400 |
---|---|---|
committer | Jean-Marc Valin <jean-marc.valin@usherbrooke.ca> | 2008-09-22 06:38:43 +0400 |
commit | b155bb8860e52e73f60a71f7c4acaadb4cb41a9c (patch) | |
tree | 765f9b26d10bc21e7c09d8139df5481b045964b6 | |
parent | 7bb339d9f92a466673ce49aa5bfda8847a7534e0 (diff) | |
parent | 5ee9715c5cb9694f0af6aa580ec9855b052d6d9b (diff) |
Merge branch 'cwrs_speedup'
Conflicts:
libcelt/cwrs.c
tests/cwrs32-test.c
-rw-r--r-- | libcelt/cwrs.c | 121 | ||||
-rw-r--r-- | libcelt/cwrs.h | 6 | ||||
-rw-r--r-- | tests/cwrs32-test.c | 8 | ||||
-rw-r--r-- | tests/cwrs64-test.c | 8 |
4 files changed, 49 insertions, 94 deletions
diff --git a/libcelt/cwrs.c b/libcelt/cwrs.c index 052ad0d..b3ff99d 100644 --- a/libcelt/cwrs.c +++ b/libcelt/cwrs.c @@ -193,138 +193,97 @@ static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){ /*Returns the number of ways of choosing _m elements from a set of size _n with replacement when a sign bit is needed for each unique element. - On input, _u should be initialized to column (_m-1) of U(n,m). - On exit, _u will be initialized to column _m of U(n,m).*/ -celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){ - celt_uint32_t ret; - celt_uint32_t ui0; - celt_uint32_t ui1; - int j; - ret=ui0=2; - celt_assert(_n>=2); - j=1; do { - ui1=_ui[j]+_ui[j-1]+ui0; - _ui[j-1]=ui0; - ui0=ui1; - ret+=ui0; - } while (++j<_n); - _ui[j-1]=ui0; - return ret; -} - -/*Returns the number of ways of choosing _m elements from a set of size _n with - replacement when a sign bit is needed for each unique element. - _u: On exit, _u[i] contains U(i+1,_m).*/ + _u: On exit, _u[i] contains U(_n,i) for i in [0..._m+1].*/ celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){ - celt_uint32_t ret; celt_uint32_t um2; int k; - /*If _m==0, _u[] should be set to zero and the return should be 1.*/ - celt_assert(_m>0); - /*We'll overflow our buffer unless _n>=2.*/ - celt_assert(_n>=2); - um2=_u[0]=1; - if(_m<=6){ - if(_m<2){ - k=1; - do _u[k]=1; - while(++k<_n); - } - else{ - k=1; - do _u[k]=(k<<1)+1; - while(++k<_n); - for(k=2;k<_m;k++)unext32(_u,_n,1); - } + int len; + len=_m+2; + _u[0]=0; + _u[1]=um2=1; + if(_n<=6){ + /*If _n==0, _u[0] should be 1 and the rest should be 0.*/ + /*If _n==1, _u[i] should be 1 for i>1.*/ + celt_assert(_n>=2); + /*If _m==0, the following do-while loop will overflow the buffer.*/ + celt_assert(_m>0); + k=2; + do _u[k]=(k<<1)-1; + while(++k<len); + for(k=2;k<_n;k++)unext32(_u+2,_m,(k<<1)+1); } else{ celt_uint32_t um1; celt_uint32_t n2m1; - _u[1]=n2m1=um1=(_m<<1)-1; - for(k=2;k<_n;k++){ + _u[2]=n2m1=um1=(_n<<1)-1; + for(k=3;k<len;k++){ /*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/ - _u[k]=um2=imusdiv32even(n2m1,um1,um2,k)+um2; - if(++k>=_n)break; - _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k>>1)+um1; + _u[k]=um2=imusdiv32even(n2m1,um1,um2,k-1)+um2; + if(++k>=len)break; + _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k-1>>1)+um1; } } - ret=1; - k=1; - do ret+=_u[k]; - while(++k<_n); - return ret<<1; + return _u[_m]+_u[_m+1]; } + /*Returns the _i'th combination of _m elements chosen from a set of size _n with associated sign bits. _y: Returns the vector of pulses. - _u: Must contain entries [1..._n] of column _m of U() on input. + _u: Must contain entries [0..._m+1] of row _n of U() on input. Its contents will be destructively modified.*/ -void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y, - celt_uint32_t *_u){ - celt_uint32_t p; - celt_uint32_t q; - int j; - int k; +void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u){ + int j; + int k; celt_assert(_n>0); - p=_nc; - q=0; j=0; k=_m; do{ - int s; - int yj; - p-=q; - q=_u[_n-j-1]; - p-=q; + celt_uint32_t p; + int s; + int yj; + p=_u[k+1]; s=_i>=p; if(s)_i-=p; yj=k; - while(q>_i){ - uprev32(_u,_n-j,--k>0); - p=q; - q=_u[_n-j-1]; - } - _i-=q; + p=_u[k]; + while(p>_i)p=_u[--k]; + _i-=p; yj-=k; _y[j]=yj-(yj<<1&-s); + uprev32(_u,k+2,0); } while(++j<_n); } + /*Returns the index of the given combination of _m elements chosen from a set of size _n with associated sign bits. _y: The vector of pulses, whose sum of absolute values must be _m. _nc: Returns V(_n,_m).*/ celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y, celt_uint32_t *_u){ - celt_uint32_t nc; celt_uint32_t i; int j; int k; /*We can't unroll the first two iterations of the loop unless _n>=2.*/ celt_assert(_n>=2); - nc=1; i=_y[_n-1]<0; _u[0]=0; for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1; k=abs(_y[_n-1]); j=_n-2; - nc+=_u[_m]; i+=_u[k]; k+=abs(_y[j]); if(_y[j]<0)i+=_u[k+1]; while(j-->0){ unext32(_u,_m+2,0); - nc+=_u[_m]; i+=_u[k]; k+=abs(_y[j]); if(_y[j]<0)i+=_u[k+1]; } - /*If _m==0, nc should not be doubled.*/ - celt_assert(_m>0); - *_nc=nc<<1; + *_nc=_u[_m]+_u[_m+1]; return i; } @@ -346,7 +305,7 @@ int get_required_bits(int N, int K, int frac) { VARDECL(celt_uint32_t,u); SAVE_STACK; - ALLOC(u,N,celt_uint32_t); + ALLOC(u,K+2,celt_uint32_t); nbits = log2_frac(ncwrs_u32(N,K,u), frac); RESTORE_STACK; } else { @@ -382,11 +341,9 @@ void encode_pulses(int *_y, int N, int K, ec_enc *enc) static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){ VARDECL(celt_uint32_t,u); - celt_uint32_t nc; SAVE_STACK; - ALLOC(u,_n,celt_uint32_t); - nc=ncwrs_u32(_n,_m,u); - cwrsi32(_n,_m,ec_dec_uint(_dec,nc),nc,_y,u); + ALLOC(u,_m+2,celt_uint32_t); + cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_y,u); RESTORE_STACK; } diff --git a/libcelt/cwrs.h b/libcelt/cwrs.h index 0a81c61..5c0d50a 100644 --- a/libcelt/cwrs.h +++ b/libcelt/cwrs.h @@ -46,8 +46,7 @@ int fits_in64(int _n, int _m); /* 32-bit versions */ celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u); -void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y, - celt_uint32_t *_u); +void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u); celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y, celt_uint32_t *_u); @@ -55,8 +54,7 @@ celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y, /* 64-bit versions */ celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u); -void cwrsi64(int _n,int _m,celt_uint64_t _i,celt_uint64_t _nc,int *_y, - celt_uint64_t *_u); +void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_y,celt_uint64_t *_u); celt_uint64_t icwrs64(int _n,int _m,celt_uint64_t *_nc,const int *_y, celt_uint64_t *_u); diff --git a/tests/cwrs32-test.c b/tests/cwrs32-test.c index 9b6aa97..b6ef23d 100644 --- a/tests/cwrs32-test.c +++ b/tests/cwrs32-test.c @@ -13,7 +13,7 @@ int main(int _argc,char **_argv){ for(n=2;n<=NMAX;n++){ int m; for(m=1;m<=MMAX;m++){ - celt_uint32_t uu[NMAX]; + celt_uint32_t uu[MMAX+2]; celt_uint32_t inc; celt_uint32_t nc; celt_uint32_t i; @@ -21,11 +21,11 @@ int main(int _argc,char **_argv){ inc=nc/10000; if(inc<1)inc=1; for(i=0;i<nc;i+=inc){ - celt_uint32_t u[NMAX>MMAX+2?NMAX:MMAX+2]; + celt_uint32_t u[MMAX+2]; int y[NMAX]; celt_uint32_t v; - memcpy(u,uu,n*sizeof(*u)); - cwrsi32(n,m,i,nc,y,u); + memcpy(u,uu,(m+2)*sizeof(*u)); + cwrsi32(n,m,i,y,u); /*printf("%6u of %u:",i,nc); for(k=0;k<n;k++)printf(" %+3i",y[k]); printf(" ->");*/ diff --git a/tests/cwrs64-test.c b/tests/cwrs64-test.c index e699deb..548fa22 100644 --- a/tests/cwrs64-test.c +++ b/tests/cwrs64-test.c @@ -14,7 +14,7 @@ int main(int _argc,char **_argv){ for(n=2;n<=NMAX;n+=3){ int m; for(m=1;m<=MMAX;m++){ - celt_uint64_t uu[NMAX]; + celt_uint64_t uu[MMAX+2]; celt_uint64_t inc; celt_uint64_t nc; celt_uint64_t i; @@ -24,12 +24,12 @@ int main(int _argc,char **_argv){ if(inc<1)inc=1; /*printf("%d/%d: %llu",n,m, nc);*/ for(i=0;i<nc;i+=inc){ - celt_uint64_t u[NMAX>MMAX+2?NMAX:MMAX+2]; + celt_uint64_t u[MMAX+2]; int y[NMAX]; celt_uint64_t v; int k; - memcpy(u,uu,n*sizeof(*u)); - cwrsi64(n,m,i,nc,y,u); + memcpy(u,uu,(m+2)*sizeof(*u)); + cwrsi64(n,m,i,y,u); /*printf("%llu of %llu:",i,nc); for(k=0;k<n;k++)printf(" %+3i",y[k]); printf(" ->");*/ |