diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-12-18 13:33:57 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-12-18 13:33:57 +0300 |
commit | 41ef7686b5ae84045c83947e2566f88566d8a98f (patch) | |
tree | f4f664110466b163c68b56244d2647e80d2b0f1d | |
parent | 29c4da0fc312af506099ec2a8deac846a9aa08c8 (diff) |
moved OSCEModel struct one level up
-rw-r--r-- | dnn/nndsp.c | 13 | ||||
-rw-r--r-- | dnn/osce.c | 62 | ||||
-rw-r--r-- | dnn/osce.h | 4 | ||||
-rw-r--r-- | silk/dec_API.c | 22 | ||||
-rw-r--r-- | silk/decode_frame.c | 5 | ||||
-rw-r--r-- | silk/init_decoder.c | 6 | ||||
-rw-r--r-- | silk/main.h | 3 | ||||
-rw-r--r-- | silk/structs.h | 1 | ||||
-rw-r--r-- | src/opus_decoder.c | 2 |
9 files changed, 61 insertions, 57 deletions
diff --git a/dnn/nndsp.c b/dnn/nndsp.c index 6a5fa44f..d5c3dc8f 100644 --- a/dnn/nndsp.c +++ b/dnn/nndsp.c @@ -42,7 +42,6 @@ #define M_PI 3.141592653589793f #endif -#define SET_ZERO(x) memset(x, 0, sizeof(x)) #define KERNEL_INDEX(i_out_channels, i_in_channels, i_kernel) ((((i_out_channels) * in_channels) + (i_in_channels)) * kernel_size + (i_kernel)) void init_adaconv_state(AdaConvState *hAdaConv) @@ -168,9 +167,9 @@ void adaconv_process_frame( celt_assert(left_padding == kernel_size - 1); /* currently only supports causal version. Non-causal version not difficult to implement but will require third loop */ celt_assert(kernel_size < frame_size); - SET_ZERO(output_buffer); - SET_ZERO(kernel_buffer); - SET_ZERO(input_buffer); + OPUS_CLEAR(output_buffer, ADACONV_MAX_FRAME_SIZE * ADACONV_MAX_OUTPUT_CHANNELS); + OPUS_CLEAR(kernel_buffer, ADACONV_MAX_KERNEL_SIZE * ADACONV_MAX_INPUT_CHANNELS * ADACONV_MAX_OUTPUT_CHANNELS); + OPUS_CLEAR(input_buffer, ADACONV_MAX_INPUT_CHANNELS * (ADACONV_MAX_FRAME_SIZE + ADACONV_MAX_KERNEL_SIZE)); #ifdef DEBUG_NNDSP print_float_vector("x_in", x_in, in_channels * frame_size); @@ -273,9 +272,9 @@ void adacomb_process_frame( (void) feature_dim; /* ToDo: figure out whether we might need this information */ - SET_ZERO(output_buffer); - SET_ZERO(kernel_buffer); - SET_ZERO(input_buffer); + OPUS_CLEAR(output_buffer, ADACOMB_MAX_FRAME_SIZE); + OPUS_CLEAR(kernel_buffer, ADACOMB_MAX_KERNEL_SIZE); + OPUS_CLEAR(input_buffer, ADACOMB_MAX_FRAME_SIZE + ADACOMB_MAX_LAG + ADACOMB_MAX_KERNEL_SIZE); OPUS_COPY(input_buffer, hAdaComb->history, kernel_size + ADACOMB_MAX_LAG); OPUS_COPY(input_buffer + kernel_size + ADACOMB_MAX_LAG, x_in, frame_size); @@ -36,6 +36,7 @@ #include "os_support.h" #include "nndsp.h" #include "float_cast.h" +#include "arch.h" #ifdef OSCE_DEBUG #include <stdio.h> @@ -50,9 +51,6 @@ #endif #define CLIP(a, min, max) (((a) < (min) ? (min) : (a)) > (max) ? (max) : (a)) -#define MAX(a, b) ((a) < (b) ? (b) : (a)) - - extern const WeightArray lacelayers_arrays[]; extern const WeightArray nolacelayers_arrays[]; @@ -112,8 +110,8 @@ static void lace_feature_net( int arch ) { - float input_buffer[4 * MAX(LACE_COND_DIM, LACE_HIDDEN_FEATURE_DIM)]; - float output_buffer[4 * MAX(LACE_COND_DIM, LACE_HIDDEN_FEATURE_DIM)]; + float input_buffer[4 * IMAX(LACE_COND_DIM, LACE_HIDDEN_FEATURE_DIM)]; + float output_buffer[4 * IMAX(LACE_COND_DIM, LACE_HIDDEN_FEATURE_DIM)]; float numbits_embedded[2 * LACE_NUMBITS_EMBEDDING_DIM]; int i_subframe; @@ -383,8 +381,8 @@ static void nolace_feature_net( int arch ) { - float input_buffer[4 * MAX(NOLACE_COND_DIM, NOLACE_HIDDEN_FEATURE_DIM)]; - float output_buffer[4 * MAX(NOLACE_COND_DIM, NOLACE_HIDDEN_FEATURE_DIM)]; + float input_buffer[4 * IMAX(NOLACE_COND_DIM, NOLACE_HIDDEN_FEATURE_DIM)]; + float output_buffer[4 * IMAX(NOLACE_COND_DIM, NOLACE_HIDDEN_FEATURE_DIM)]; float numbits_embedded[2 * NOLACE_NUMBITS_EMBEDDING_DIM]; int i_subframe; @@ -815,49 +813,47 @@ void osce_reset(silk_OSCE_struct *hOSCE, int method) } -void osce_init(silk_OSCE_struct *hOSCE, int method) +int osce_load_models(OSCEModel *model, const unsigned char *data, int len) { -#ifndef USE_WEIGHTS_FILE - /* initialize all models */ -#ifndef DISABLE_LACE - init_lace(&hOSCE->model.lace, lacelayers_arrays); -#endif + int ret = 0; + WeightArray *list; -#ifndef DISABLE_NOLACE - init_nolace(&hOSCE->model.nolace, nolacelayers_arrays); + if (data != NULL && len) + { + /* init from buffer */ + parse_weights(&list, data, len); +#ifndef DISABLE_LACE + if (ret == 0) {ret = init_lace(&model->lace, list);} #endif - osce_reset(hOSCE, method); -#else - (void *) hOSCE; - (void) method; +#ifndef DISABLE_LACE + if (ret == 0) {ret = init_nolace(&model->nolace, list);} #endif -} + free(list); + } else + { #ifdef USE_WEIGHTS_FILE -int osce_load_models(silk_OSCE_struct *hOSCE, const unsigned char *data, int len) -{ - WeightArray *list; - int ret = 0; - + return -1; +#else #ifndef DISABLE_LACE - if (ret == 0) {ret = init_lace(&hOSCE->model.lace, list);} + if (ret == 0) {ret = init_lace(&model->lace, lacelayers_arrays);} #endif #ifndef DISABLE_LACE - if (ret == 0) {ret = init_nolace(&hOSCE->model.nolace, list);} + if (ret == 0) {ret = init_nolace(&model->nolace, nolacelayers_arrays);} #endif - osce_reset(hOSCE, OSCE_DEFAULT_METHOD); - - free(list); +#endif /* USE_WEIGHTS_FILE */ + } + ret = ret ? -1 : 0; return ret; } -#endif void osce_enhance_frame( + OSCEModel *model, /* I OSCE model struct */ silk_decoder_state *psDec, /* I/O Decoder state */ silk_decoder_control *psDecCtrl, /* I Decoder control */ opus_int16 xq[], /* I/O Decoded speech */ @@ -894,12 +890,12 @@ void osce_enhance_frame( break; #ifndef DISABLE_LACE case OSCE_METHOD_LACE: - lace_process_20ms_frame(&psDec->osce.model.lace, &psDec->osce.state.lace, out_buffer, in_buffer, features, numbits, periods, arch); + lace_process_20ms_frame(&model->lace, &psDec->osce.state.lace, out_buffer, in_buffer, features, numbits, periods, arch); break; #endif #ifndef DISABLE_NOLACE case OSCE_METHOD_NOLACE: - nolace_process_20ms_frame(&psDec->osce.model.nolace, &psDec->osce.state.nolace, out_buffer, in_buffer, features, numbits, periods, arch); + nolace_process_20ms_frame(&model->nolace, &psDec->osce.state.nolace, out_buffer, in_buffer, features, numbits, periods, arch); break; #endif default: @@ -65,6 +65,7 @@ void osce_enhance_frame( + OSCEModel *model, /* I OSCE model struct */ silk_decoder_state *psDec, /* I/O Decoder state */ silk_decoder_control *psDecCtrl, /* I Decoder control */ opus_int16 xq[], /* I/O Decoded speech */ @@ -73,8 +74,7 @@ void osce_enhance_frame( ); -void osce_init(silk_OSCE_struct *hOSCE, int method); -int osce_load_models(silk_OSCE_struct *hOSCE, const unsigned char *data, int len); +int osce_load_models(OSCEModel *hModel, const unsigned char *data, int len); void osce_reset(silk_OSCE_struct *hOSCE, int method); diff --git a/silk/dec_API.c b/silk/dec_API.c index 134b3c38..34ceb676 100644 --- a/silk/dec_API.c +++ b/silk/dec_API.c @@ -35,6 +35,7 @@ POSSIBILITY OF SUCH DAMAGE. #ifdef ENABLE_OSCE #include "osce.h" +#include "osce_structs.h" #endif /************************/ @@ -46,6 +47,9 @@ typedef struct { opus_int nChannelsAPI; opus_int nChannelsInternal; opus_int prev_decode_only_middle; +#ifdef ENABLE_OSCE + OSCEModel osce_model; +#endif } silk_decoder; /*********************/ @@ -56,14 +60,10 @@ typedef struct { opus_int silk_LoadOSCEModels(void *decState, const unsigned char *data, int len) { -#if defined(ENABLE_OSCE) && defined(USE_WEIGHTS_FILE) - opus_int n, ret = SILK_NO_ERROR; - - silk_decoder_state *channel_state = ((silk_decoder *)decState)->channel_state; +#ifdef ENABLE_OSCE + opus_int ret = SILK_NO_ERROR; - for ( n = 0; n < DECODER_NUM_CHANNELS; n++ ) { - ret |= osce_load_models(&channel_state[n].osce, data, len); - } + ret = osce_load_models(&((silk_decoder *)decState)->osce_model, data, len); return ret; #else @@ -111,6 +111,11 @@ opus_int silk_InitDecoder( /* O Returns error co opus_int n, ret = SILK_NO_ERROR; silk_decoder_state *channel_state = ((silk_decoder *)decState)->channel_state; +#ifndef USE_WEIGHTS_FILE + /* load osce models */ + silk_LoadOSCEModels(decState, NULL, 0); +#endif + for( n = 0; n < DECODER_NUM_CHANNELS; n++ ) { ret = silk_init_decoder( &channel_state[ n ] ); } @@ -352,6 +357,9 @@ opus_int silk_Decode( /* O Returns error co #ifdef ENABLE_DEEP_PLC n == 0 ? lpcnet : NULL, #endif +#ifdef ENABLE_OSCE + &psDec->osce_model, +#endif arch); } else { silk_memset( &samplesOut1_tmp[ n ][ 2 ], 0, nSamplesOutDec * sizeof( opus_int16 ) ); diff --git a/silk/decode_frame.c b/silk/decode_frame.c index 3f35cc5b..48f74aef 100644 --- a/silk/decode_frame.c +++ b/silk/decode_frame.c @@ -50,6 +50,9 @@ opus_int silk_decode_frame( #ifdef ENABLE_DEEP_PLC LPCNetPLCState *lpcnet, #endif +#ifdef ENABLE_OSCE + OSCEModel *osce_model, +#endif int arch /* I Run-time architecture */ ) { @@ -109,7 +112,7 @@ opus_int silk_decode_frame( /********************************************************/ /* Run SILK enhancer */ /********************************************************/ - osce_enhance_frame( psDec, psDecCtrl, pOut, ec_tell(psRangeDec) - ec_start, arch ); + osce_enhance_frame( osce_model, psDec, psDecCtrl, pOut, ec_tell(psRangeDec) - ec_start, arch ); #endif /********************************************************/ diff --git a/silk/init_decoder.c b/silk/init_decoder.c index 72977975..01bc4b7a 100644 --- a/silk/init_decoder.c +++ b/silk/init_decoder.c @@ -75,11 +75,7 @@ opus_int silk_init_decoder( ) { /* Clear the entire encoder state, except anything copied */ -#ifdef ENABLE_OSCE -#ifndef USE_WEIGHTS_FILE - osce_init(&psDec->osce, OSCE_DEFAULT_METHOD); -#endif -#endif + silk_memset( psDec, 0, sizeof( silk_decoder_state ) ); silk_reset_decoder( psDec ); diff --git a/silk/main.h b/silk/main.h index d5cb2a6e..cd576d8c 100644 --- a/silk/main.h +++ b/silk/main.h @@ -417,6 +417,9 @@ opus_int silk_decode_frame( #ifdef ENABLE_DEEP_PLC LPCNetPLCState *lpcnet, #endif +#ifdef ENABLE_OSCE + OSCEModel *osce_model, +#endif int arch /* I Run-time architecture */ ); diff --git a/silk/structs.h b/silk/structs.h index 123b384f..38243be1 100644 --- a/silk/structs.h +++ b/silk/structs.h @@ -246,7 +246,6 @@ typedef struct { #ifdef ENABLE_OSCE typedef struct { OSCEFeatureState features; - OSCEModel model; OSCEState state; int method; } silk_OSCE_struct; diff --git a/src/opus_decoder.c b/src/opus_decoder.c index c4fe60ef..eba0010e 100644 --- a/src/opus_decoder.c +++ b/src/opus_decoder.c @@ -1057,7 +1057,7 @@ int opus_decoder_ctl(OpusDecoder *st, int request, ...) goto bad_arg; } ret = lpcnet_plc_load_model(&st->lpcnet, data, len); - ret |= osce_load_models(silk_dec, data, len); + ret |= silk_LoadOSCEModels(silk_dec, data, len); } break; #endif |