diff options
author | Timothy B. Terriberry <tterribe@xiph.org> | 2008-04-04 18:16:19 +0400 |
---|---|---|
committer | Jean-Marc Valin <jean-marc.valin@usherbrooke.ca> | 2008-04-05 08:31:35 +0400 |
commit | d883670bf792ad02750fe481646bcd8b4cf6ad72 (patch) | |
tree | dd5940d3fd73e87c898526527a0d48e230519b43 /libcelt/cwrs.c | |
parent | ae76e553db3d411b5ebb9a8b6df28c18d3a30f82 (diff) |
Rework CWRS code.
This eliminates O(nm) extra lookups on decode and reduces the rate control
from O(nm^2) to O(nm), in addition to eliminating O(m) lookups on both encode
and decode.
Although the interface is slightly more complex, the internal code is
simpler.
Diffstat (limited to 'libcelt/cwrs.c')
-rw-r--r-- | libcelt/cwrs.c | 430 |
1 files changed, 199 insertions, 231 deletions
diff --git a/libcelt/cwrs.c b/libcelt/cwrs.c index 12835de..60880c6 100644 --- a/libcelt/cwrs.c +++ b/libcelt/cwrs.c @@ -1,4 +1,4 @@ -/* (C) 2007 Timothy B. Terriberry +/* (C) 2007-2008 Timothy B. Terriberry (C) 2008 Jean-Marc Valin */ /* Redistribution and use in source and binary forms, with or without @@ -29,8 +29,13 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -/* Functions for encoding and decoding pulse vectors. For more details, see: - http://people.xiph.org/~tterribe/notes/cwrs.html +/* Functions for encoding and decoding pulse vectors. + These are based on the function + U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1), + U(n,1) = U(1,m) = 2, + which counts the number of ways of placing m pulses in n dimensions, where + at least one pulse lies in dimension 0. + For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html */ #ifdef HAVE_CONFIG_H @@ -38,133 +43,130 @@ #endif #include <stdlib.h> +#include <string.h> #include "cwrs.h" #include "mathops.h" -/* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n, - compute ncwrs() for m+1, for all n. Could also be used when m and n are - swapped just by changing nc */ -static inline void next_ncwrs32(celt_uint32_t *nc, int len, int nc0) -{ - int i; - celt_uint32_t mem; - - mem = nc[0]; - nc[0] = nc0; - for (i=1;i<len;i++) - { - celt_uint32_t tmp = nc[i]+nc[i-1]+mem; - mem = nc[i]; - nc[i] = tmp; - } +/*Computes the next row/column of any recurrence that obeys the relation + u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1]. + _ui0 is the base case for the new row/column.*/ +static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){ + celt_uint32_t ui1; + int j; + for(j=1;j<_len;j++){ + ui1=_ui[j]+_ui[j-1]+_ui0; + _ui[j-1]=_ui0; + _ui0=ui1; + } + _ui[j-1]=_ui0; } -/* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n, - compute ncwrs() for m-1, for all n. 
Could also be used when m and n are - swapped just by changing nc */ -static inline void prev_ncwrs32(celt_uint32_t *nc, int len, int nc0) -{ - int i; - celt_uint32_t mem; - - mem = nc[0]; - nc[0] = nc0; - for (i=1;i<len;i++) - { - celt_uint32_t tmp = nc[i]-nc[i-1]-mem; - mem = nc[i]; - nc[i] = tmp; - } +static inline void unext64(celt_uint64_t *_ui,int _len,celt_uint64_t _ui0){ + celt_uint64_t ui1; + int j; + for(j=1;j<_len;j++){ + ui1=_ui[j]+_ui[j-1]+_ui0; + _ui[j-1]=_ui0; + _ui0=ui1; + } + _ui[j-1]=_ui0; } -static inline void next_ncwrs64(celt_uint64_t *nc, int len, int nc0) -{ - int i; - celt_uint64_t mem; - - mem = nc[0]; - nc[0] = nc0; - for (i=1;i<len;i++) - { - celt_uint64_t tmp = nc[i]+nc[i-1]+mem; - mem = nc[i]; - nc[i] = tmp; - } +/*Computes the previous row/column of any recurrence that obeys the relation + u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1]. + _ui0 is the base case for the new row/column.*/ +static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){ + celt_uint32_t ui1; + int j; + for(j=1;j<_n;j++){ + ui1=_ui[j]-_ui[j-1]-_ui0; + _ui[j-1]=_ui0; + _ui0=ui1; + } + _ui[j-1]=_ui0; } -static inline void prev_ncwrs64(celt_uint64_t *nc, int len, int nc0) -{ - int i; - celt_uint64_t mem; - - mem = nc[0]; - nc[0] = nc0; - for (i=1;i<len;i++) - { - celt_uint64_t tmp = nc[i]-nc[i-1]-mem; - mem = nc[i]; - nc[i] = tmp; - } +static inline void uprev64(celt_uint64_t *_ui,int _n,celt_uint64_t _ui0){ + celt_uint64_t ui1; + int j; + for(j=1;j<_n;j++){ + ui1=_ui[j]-_ui[j-1]-_ui0; + _ui[j-1]=_ui0; + _ui0=ui1; + } + _ui[j-1]=_ui0; } -/*Returns the numer of ways of choosing _m elements from a set of size _n with - replacement when a sign bit is needed for each unique element.*/ -celt_uint32_t ncwrs(int _n,int _m) -{ - int i; - celt_uint32_t ret; - VARDECL(celt_uint32_t, nc); - SAVE_STACK; - ALLOC(nc,_n+1, celt_uint32_t); - for (i=0;i<_n+1;i++) - nc[i] = 1; - for (i=0;i<_m;i++) - next_ncwrs32(nc, _n+1, 0); - ret = nc[_n]; - RESTORE_STACK; - return ret; 
+/*Returns the number of ways of choosing _m elements from a set of size _n with + replacement when a sign bit is needed for each unique element. + On input, _u should be initialized to column (_m-1) of U(n,m). + On exit, _u will be initialized to column _m of U(n,m).*/ +celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){ + celt_uint32_t ret; + celt_uint32_t ui0; + celt_uint32_t ui1; + int j; + ret=ui0=2; + for(j=1;j<_n;j++){ + ui1=_ui[j]+_ui[j-1]+ui0; + _ui[j-1]=ui0; + ui0=ui1; + ret+=ui0; + } + _ui[j-1]=ui0; + return ret; } -/*Returns the numer of ways of choosing _m elements from a set of size _n with - replacement when a sign bit is needed for each unique element.*/ -celt_uint64_t ncwrs64(int _n,int _m) -{ - int i; - celt_uint64_t ret; - VARDECL(celt_uint64_t, nc); - SAVE_STACK; - ALLOC(nc,_n+1, celt_uint64_t); - for (i=0;i<_n+1;i++) - nc[i] = 1; - for (i=0;i<_m;i++) - next_ncwrs64(nc, _n+1, 0); - ret = nc[_n]; - RESTORE_STACK; - return ret; +celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_ui){ + celt_uint64_t ret; + celt_uint64_t ui0; + celt_uint64_t ui1; + int j; + ret=ui0=2; + for(j=1;j<_n;j++){ + ui1=_ui[j]+_ui[j-1]+ui0; + _ui[j-1]=ui0; + ui0=ui1; + ret+=ui0; + } + _ui[j-1]=ui0; + return ret; } +/*Returns the number of ways of choosing _m elements from a set of size _n with + replacement when a sign bit is needed for each unique element. + On exit, _u will be initialized to column _m of U(n,m).*/ +celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){ + int k; + memset(_u,0,_n*sizeof(*_u)); + if(_m<=0)return 1; + if(_n<=0)return 0; + for(k=1;k<_m;k++)unext32(_u,_n,2); + return ncwrs_unext32(_n,_u); +} + +celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){ + int k; + memset(_u,0,_n*sizeof(*_u)); + if(_m<=0)return 1; + if(_n<=0)return 0; + for(k=1;k<_m;k++)unext64(_u,_n,2); + return ncwrs_unext64(_n,_u); +} /*Returns the _i'th combination of _m elements chosen from a set of size _n with associated sign bits. 
- _x: Returns the combination with elements sorted in ascending order. - _s: Returns the associated sign bits.*/ -void cwrsi(int _n,int _m,celt_uint32_t _i,int * restrict _x,int * restrict _s){ + _x: Returns the combination with elements sorted in ascending order. + _s: Returns the associated sign bits. + _u: Temporary storage already initialized to column _m of U(n,m). + Its contents will be overwritten.*/ +void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){ int j; int k; - VARDECL(celt_uint32_t, nc); - SAVE_STACK; - ALLOC(nc,_n+1, celt_uint32_t); - for (j=0;j<_n+1;j++) - nc[j] = 1; - for (k=0;k<_m-1;k++) - next_ncwrs32(nc, _n+1, 0); for(k=j=0;k<_m;k++){ - celt_uint32_t pn, p, t; - /*p=ncwrs(_n-j,_m-k-1); - pn=ncwrs(_n-j-1,_m-k-1);*/ - p=nc[_n-j]; - pn=nc[_n-j-1]; - p+=pn; + celt_uint32_t p; + celt_uint32_t t; + p=_u[_n-j-1]; if(k>0){ t=p>>1; if(t<=_i||_s[k-1])_i+=t; @@ -172,89 +174,23 @@ void cwrsi(int _n,int _m,celt_uint32_t _i,int * restrict _x,int * restrict _s){ while(p<=_i){ _i-=p; j++; - p=pn; - /*pn=ncwrs(_n-j-1,_m-k-1);*/ - pn=nc[_n-j-1]; - p+=pn; + p=_u[_n-j-1]; } t=p>>1; _s[k]=_i>=t; _x[k]=j; if(_s[k])_i-=t; - if (k<_m-2) - prev_ncwrs32(nc, _n-j+1, 0); - else - prev_ncwrs32(nc, _n-j+1, 1); - } - RESTORE_STACK; -} - -/*Returns the index of the given combination of _m elements chosen from a set - of size _n with associated sign bits. - _x: The combination with elements sorted in ascending order. 
- _s: The associated sign bits.*/ -celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s, celt_uint32_t *bound){ - celt_uint32_t i; - int j; - int k; - VARDECL(celt_uint32_t, nc); - SAVE_STACK; - ALLOC(nc,_n+1, celt_uint32_t); - for (j=0;j<_n+1;j++) - nc[j] = 1; - for (k=0;k<_m;k++) - next_ncwrs32(nc, _n+1, 0); - if (bound) - *bound = nc[_n]; - i=0; - for(k=j=0;k<_m;k++){ - celt_uint32_t pn; - celt_uint32_t p; - if (k<_m-1) - prev_ncwrs32(nc, _n-j+1, 0); - else - prev_ncwrs32(nc, _n-j+1, 1); - /*p=ncwrs(_n-j,_m-k-1); - pn=ncwrs(_n-j-1,_m-k-1);*/ - p=nc[_n-j]; - pn=nc[_n-j-1]; - p+=pn; - if(k>0)p>>=1; - while(j<_x[k]){ - i+=p; - j++; - p=pn; - /*pn=ncwrs(_n-j-1,_m-k-1);*/ - pn=nc[_n-j-1]; - p+=pn; - } - if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1; + uprev32(_u,_n-j,2); } - RESTORE_STACK; - return i; } -/*Returns the _i'th combination of _m elements chosen from a set of size _n - with associated sign bits. - _x: Returns the combination with elements sorted in ascending order. - _s: Returns the associated sign bits.*/ -void cwrsi64(int _n,int _m,celt_uint64_t _i,int * restrict _x,int * restrict _s){ +void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,celt_uint64_t *_u){ int j; int k; - VARDECL(celt_uint64_t, nc); - SAVE_STACK; - ALLOC(nc,_n+1, celt_uint64_t); - for (j=0;j<_n+1;j++) - nc[j] = 1; - for (k=0;k<_m-1;k++) - next_ncwrs64(nc, _n+1, 0); for(k=j=0;k<_m;k++){ - celt_uint64_t pn, p, t; - /*p=ncwrs64(_n-j,_m-k-1); - pn=ncwrs64(_n-j-1,_m-k-1);*/ - p=nc[_n-j]; - pn=nc[_n-j-1]; - p+=pn; + celt_uint64_t p; + celt_uint64_t t; + p=_u[_n-j-1]; if(k>0){ t=p>>1; if(t<=_i||_s[k-1])_i+=t; @@ -262,65 +198,61 @@ void cwrsi64(int _n,int _m,celt_uint64_t _i,int * restrict _x,int * restrict _s) while(p<=_i){ _i-=p; j++; - p=pn; - /*pn=ncwrs64(_n-j-1,_m-k-1);*/ - pn=nc[_n-j-1]; - p+=pn; + p=_u[_n-j-1]; } t=p>>1; _s[k]=_i>=t; _x[k]=j; if(_s[k])_i-=t; - if (k<_m-2) - prev_ncwrs64(nc, _n-j+1, 0); - else - prev_ncwrs64(nc, _n-j+1, 1); + uprev64(_u,_n-j,2); } - 
RESTORE_STACK; } /*Returns the index of the given combination of _m elements chosen from a set of size _n with associated sign bits. - _x: The combination with elements sorted in ascending order. - _s: The associated sign bits.*/ -celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){ + _x: The combination with elements sorted in ascending order. + _s: The associated sign bits. + _u: Temporary storage already initialized to column _m of U(n,m). + Its contents will be overwritten.*/ +celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s, + celt_uint32_t *_u){ + celt_uint32_t i; + int j; + int k; + i=0; + for(k=j=0;k<_m;k++){ + celt_uint32_t p; + p=_u[_n-j-1]; + if(k>0)p>>=1; + while(j<_x[k]){ + i+=p; + j++; + p=_u[_n-j-1]; + } + if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1; + uprev32(_u,_n-j,2); + } + return i; +} + +celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, + celt_uint64_t *_u){ celt_uint64_t i; int j; int k; - VARDECL(celt_uint64_t, nc); - SAVE_STACK; - ALLOC(nc,_n+1, celt_uint64_t); - for (j=0;j<_n+1;j++) - nc[j] = 1; - for (k=0;k<_m;k++) - next_ncwrs64(nc, _n+1, 0); - if (bound) - *bound = nc[_n]; i=0; for(k=j=0;k<_m;k++){ - celt_uint64_t pn; celt_uint64_t p; - if (k<_m-1) - prev_ncwrs64(nc, _n-j+1, 0); - else - prev_ncwrs64(nc, _n-j+1, 1); - /*p=ncwrs64(_n-j,_m-k-1); - pn=ncwrs64(_n-j-1,_m-k-1);*/ - p=nc[_n-j]; - pn=nc[_n-j-1]; - p+=pn; + p=_u[_n-j-1]; if(k>0)p>>=1; while(j<_x[k]){ i+=p; j++; - p=pn; - /*pn=ncwrs64(_n-j-1,_m-k-1);*/ - pn=nc[_n-j-1]; - p+=pn; + p=_u[_n-j-1]; } if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1; + uprev64(_u,_n-j,2); } - RESTORE_STACK; return i; } @@ -363,47 +295,83 @@ void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){ } } +static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s, + ec_enc *_enc){ + VARDECL(celt_uint32_t,u); + celt_uint32_t nc; + celt_uint32_t i; + SAVE_STACK; + ALLOC(u,_n,celt_uint32_t); + nc=ncwrs_u32(_n,_m,u); + 
i=icwrs32(_n,_m,_x,_s,u); + ec_enc_uint(_enc,i,nc); + RESTORE_STACK; +} + +static inline void encode_comb64(int _n,int _m,const int *_x,const int *_s, + ec_enc *_enc){ + VARDECL(celt_uint64_t,u); + celt_uint64_t nc; + celt_uint64_t i; + SAVE_STACK; + ALLOC(u,_n,celt_uint64_t); + nc=ncwrs_u64(_n,_m,u); + i=icwrs64(_n,_m,_x,_s,u); + ec_enc_uint64(_enc,i,nc); + RESTORE_STACK; +} + void encode_pulses(int *_y, int N, int K, ec_enc *enc) { VARDECL(int, comb); VARDECL(int, signs); SAVE_STACK; - + ALLOC(comb, K, int); ALLOC(signs, K, int); - + pulse2comb(N, K, comb, signs, _y); /* Simple heuristic to figure out whether it fits in 32 bits */ if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31) { - celt_uint32_t bound, id; - id = icwrs(N, K, comb, signs, &bound); - ec_enc_uint(enc,id,bound); + encode_comb32(N, K, comb, signs, enc); } else { - celt_uint64_t bound, id; - id = icwrs64(N, K, comb, signs, &bound); - ec_enc_uint64(enc,id,bound); + encode_comb64(N, K, comb, signs, enc); } RESTORE_STACK; } +static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){ + VARDECL(celt_uint32_t,u); + SAVE_STACK; + ALLOC(u,_n,celt_uint32_t); + cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u); + RESTORE_STACK; +} + +static inline void decode_comb64(int _n,int _m,int *_x,int *_s,ec_dec *_dec){ + VARDECL(celt_uint64_t,u); + SAVE_STACK; + ALLOC(u,_n,celt_uint64_t); + cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_x,_s,u); + RESTORE_STACK; +} + void decode_pulses(int *_y, int N, int K, ec_dec *dec) { VARDECL(int, comb); VARDECL(int, signs); SAVE_STACK; - + ALLOC(comb, K, int); ALLOC(signs, K, int); /* Simple heuristic to figure out whether it fits in 32 bits */ if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31) { - cwrsi(N, K, ec_dec_uint(dec, ncwrs(N, K)), comb, signs); - comb2pulse(N, K, _y, comb, signs); + decode_comb32(N, K, comb, signs, dec); } else { - cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs); - comb2pulse(N, K, _y, comb, signs); + 
decode_comb64(N, K, comb, signs, dec); } + comb2pulse(N, K, _y, comb, signs); RESTORE_STACK; } - |