diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-12-11 13:13:11 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-12-11 13:13:11 +0300 |
commit | 241781d4c0defd6214e2baef9a3376eb6d77d8a7 (patch) | |
tree | 2bc4c6bde578ea0370eaf57975c2fced5731006d | |
parent | 9de91f631c2147943ac70679d1f37239fb136a53 (diff) | |
parent | d2f07aeb10636596d25990cfc1eb7b27f891e7bd (diff) |
Merge remote-tracking branch 'origin/opus-ng-lace-integration2' into opus-ng-lace-integration2-model-tuningopus-ng-lace-integration2-model-tuning
-rw-r--r-- | dnn/osce.c | 2 | ||||
-rw-r--r-- | dnn/osce.h | 6 | ||||
-rw-r--r-- | dnn/osce_structs.h | 24 | ||||
-rw-r--r-- | dnn/torch/osce/export_model_weights.py | 21 | ||||
-rw-r--r-- | dnn/torch/weight-exchange/wexchange/torch/torch.py | 52 | ||||
-rw-r--r-- | silk/API.h | 6 | ||||
-rw-r--r-- | silk/control.h | 5 | ||||
-rw-r--r-- | silk/dec_API.c | 25 | ||||
-rw-r--r-- | silk/decode_frame.c | 4 | ||||
-rw-r--r-- | silk/init_decoder.c | 29 | ||||
-rw-r--r-- | silk/main.h | 4 | ||||
-rw-r--r-- | silk/structs.h | 8 | ||||
-rw-r--r-- | src/opus_decoder.c | 17 |
13 files changed, 167 insertions, 36 deletions
@@ -833,7 +833,7 @@ void osce_enhance_frame( /* enhancement only implemented for 20 ms frame at 16kHz */ if (psDec->fs_kHz != 16 || psDec->nb_subfr != 4) { - /* Question: reset state? */ + osce_reset(&psDec->osce, psDec->osce.method); return; } @@ -23,7 +23,13 @@ #define OSCE_METHOD_NOLACE 2 #endif +#if !defined(DISABLE_NOLACE) #define OSCE_DEFAULT_METHOD OSCE_METHOD_NOLACE +#elif !defined(DISABLE_LACE) +#define OSCE_DEFAULT_METHOD OSCE_METHOD_LACE +#else +#define OSCE_DEFAULT_METHOD OSCE_METHOD_NONE +#endif diff --git a/dnn/osce_structs.h b/dnn/osce_structs.h index 4f1720dd..67fbf811 100644 --- a/dnn/osce_structs.h +++ b/dnn/osce_structs.h @@ -3,8 +3,12 @@ #include "opus_types.h" #include "osce_config.h" +#ifndef DISABLE_LACE #include "lace_data.h" +#endif +#ifndef DISABLE_NOLACE #include "nolace_data.h" +#endif #include "nndsp.h" #include "nnet.h" @@ -18,8 +22,9 @@ typedef struct { float signal_history[OSCE_FEATURES_MAX_HISTORY]; } OSCEFeatureState; -/* LACE */ +#ifndef DISABLE_LACE +/* LACE */ typedef struct { float feature_net_conv2_state[LACE_FEATURE_NET_CONV2_STATE_SIZE]; float feature_net_gru_state[LACE_COND_DIM]; /* ToDo: fix! */ @@ -33,12 +38,14 @@ typedef struct { typedef struct { LACELayers layers; - LACEState state; float window[LACE_OVERLAP_SIZE]; } LACE; -/* NoLACE */ +#endif /* #ifndef DISABLE_LACE */ + +#ifndef DISABLE_NOLACE +/* NoLACE */ typedef struct { float feature_net_conv2_state[NOLACE_FEATURE_NET_CONV2_STATE_SIZE]; float feature_net_gru_state[NOLACE_COND_DIM]; @@ -62,19 +69,28 @@ typedef struct { typedef struct { NOLACELayers layers; - NoLACEState state; float window[LACE_OVERLAP_SIZE]; } NoLACE; +#endif /* #ifndef DISABLE_NOLACE */ + /* OSCEModel */ typedef struct { +#ifndef DISABLE_LACE LACE lace; +#endif +#ifndef DISABLE_NOLACE NoLACE nolace; +#endif } OSCEModel; typedef union { +#ifndef DISABLE_LACE LACEState lace; +#endif +#ifndef DISABLE_NOLACE NoLACEState nolace; +#endif } OSCEState; #endif
\ No newline at end of file diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py index c3b723c7..786d3200 100644 --- a/dnn/torch/osce/export_model_weights.py +++ b/dnn/torch/osce/export_model_weights.py @@ -51,6 +51,7 @@ parser = argparse.ArgumentParser() parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoint') parser.add_argument('output_dir', type=str, help='output folder') +parser.add_argument('--quantize', action="store_true", help='quantization according to schedule') schedules = { @@ -60,15 +61,15 @@ schedules = { ('feature_net.conv2', dict(quantize=True, scale=None)), ('feature_net.tconv', dict(quantize=True, scale=None)), ('feature_net.gru', dict()), - ('cf1', dict()), - ('cf2', dict()), - ('af1', dict()), + ('cf1', dict(quantize=True, scale=None)), + ('cf2', dict(quantize=True, scale=None)), + ('af1', dict(quantize=True, scale=None)), ('tdshape1', dict()), ('tdshape2', dict()), ('tdshape3', dict()), - ('af2', dict()), - ('af3', dict()), - ('af4', dict()), + ('af2', dict(quantize=True, scale=None)), + ('af3', dict(quantize=True, scale=None)), + ('af4', dict(quantize=True, scale=None)), ('post_cf1', dict(quantize=True, scale=None)), ('post_cf2', dict(quantize=True, scale=None)), ('post_af1', dict(quantize=True, scale=None)), @@ -81,9 +82,9 @@ schedules = { ('feature_net.conv2', dict(quantize=True, scale=None)), ('feature_net.tconv', dict(quantize=True, scale=None)), ('feature_net.gru', dict()), - ('cf1', dict()), - ('cf2', dict()), - ('af1', dict()) + ('cf1', dict(quantize=True, scale=None)), + ('cf2', dict(quantize=True, scale=None)), + ('af1', dict(quantize=True, scale=None)) ] } @@ -161,7 +162,7 @@ if __name__ == "__main__": cwriter.header.write(f"#define {model_name.upper()}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}\n") # dump layers - if model_name in schedules: + if model_name in schedules and args.quantize: osce_scheduled_dump(cwriter, model_name, model, schedules[model_name]) else: osce_dump_generic(cwriter, model_name, model) diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py index 2dcee1c5..00a1a4bb 100644 --- a/dnn/torch/weight-exchange/wexchange/torch/torch.py +++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py @@ -45,7 +45,7 @@ except: from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer, print_tconv1d_layer, print_conv2d_layer -def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', kernel_scale=1/128, kernel_quantize=False, gain_scale=1/128, gain_quantize=False): +def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', scale=1/128, quantize=False): w_kernel = adaconv.conv_kernel.weight.detach().cpu().numpy().copy() @@ -54,14 +54,34 @@ def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', kernel_sc b_gain = adaconv.filter_gain.bias.detach().cpu().numpy().copy() if isinstance(where, CWriter): + # pad kernel for quantization + left_padding = adaconv.padding[0] + kernel_size = adaconv.kernel_size + in_channels = adaconv.in_channels + out_channels = adaconv.out_channels + feature_dim = adaconv.feature_dim + + if quantize and kernel_size % 8: + kernel_padding = 8 - (kernel_size % 8) + w_kernel = np.concatenate( + (np.zeros((out_channels, in_channels, kernel_padding, feature_dim)), w_kernel.reshape(out_channels, in_channels, kernel_size, feature_dim)), + dtype=w_kernel.dtype, + axis=2).reshape(-1, feature_dim) + b_kernel = np.concatenate( + (np.zeros((out_channels, in_channels, kernel_padding)), b_kernel.reshape(out_channels, in_channels, kernel_size)), + dtype=b_kernel.dtype, + axis=2).reshape(-1) + left_padding += kernel_padding + kernel_size += kernel_padding + # write relevant scalar parameters to header file where.header.write(f""" #define {name.upper()}_FILTER_GAIN_A {adaconv.filter_gain_a:f}f #define {name.upper()}_FILTER_GAIN_B {adaconv.filter_gain_b:f}f #define {name.upper()}_SHAPE_GAIN {adaconv.shape_gain:f}f -#define {name.upper()}_KERNEL_SIZE {adaconv.kernel_size} +#define {name.upper()}_KERNEL_SIZE {kernel_size} #define {name.upper()}_FRAME_SIZE {adaconv.frame_size} -#define {name.upper()}_LEFT_PADDING {adaconv.padding[0]} +#define {name.upper()}_LEFT_PADDING {left_padding} #define {name.upper()}_OVERLAP_SIZE {adaconv.overlap_size} #define {name.upper()}_IN_CHANNELS {adaconv.in_channels} #define {name.upper()}_OUT_CHANNELS {adaconv.out_channels} @@ -70,8 +90,8 @@ def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', kernel_sc """ ) - print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=kernel_scale, format='torch', sparse=False, diagonal=False, quantize=kernel_quantize) - print_dense_layer(where, name + "_gain", w_gain, b_gain, scale=gain_scale, format='torch', sparse=False, diagonal=False, quantize=gain_quantize) + print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=scale, format='torch', sparse=False, diagonal=False, quantize=quantize) + print_dense_layer(where, name + "_gain", w_gain, b_gain, format='torch', sparse=False, diagonal=False, quantize=False) else: @@ -81,7 +101,7 @@ def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', kernel_sc np.save(where, 'bias_gain.npy', b_gain) -def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', kernel_scale=1/128, kernel_quantize=False, gain_scale=1/128, gain_quantize=False, global_gain_scale=1/128, global_gain_quantize=False): +def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', scale=1/128, quantize=False): w_kernel = adaconv.conv_kernel.weight.detach().cpu().numpy().copy() @@ -93,13 +113,23 @@ def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', kernel_sc if isinstance(where, CWriter): + # pad kernel for quantization + left_padding = adaconv.padding[0] + kernel_size = adaconv.kernel_size + + if quantize and w_kernel.shape[0] % 8: + kernel_padding = 8 - (w_kernel.shape[0] % 8) + w_kernel = np.concatenate((np.zeros((kernel_padding, w_kernel.shape[1])), w_kernel), dtype=w_kernel.dtype) + b_kernel = np.concatenate((np.zeros((kernel_padding)), b_kernel), dtype=b_kernel.dtype) + left_padding += kernel_padding + kernel_size += kernel_padding # write relevant scalar parameters to header file where.header.write(f""" #define {name.upper()}_FILTER_GAIN_A {adaconv.filter_gain_a:f}f #define {name.upper()}_FILTER_GAIN_B {adaconv.filter_gain_b:f}f #define {name.upper()}_LOG_GAIN_LIMIT {adaconv.log_gain_limit:f}f -#define {name.upper()}_KERNEL_SIZE {adaconv.kernel_size} -#define {name.upper()}_LEFT_PADDING {adaconv.padding[0]} +#define {name.upper()}_KERNEL_SIZE {kernel_size} +#define {name.upper()}_LEFT_PADDING {left_padding} #define {name.upper()}_FRAME_SIZE {adaconv.frame_size} #define {name.upper()}_OVERLAP_SIZE {adaconv.overlap_size} #define {name.upper()}_IN_CHANNELS {adaconv.in_channels} @@ -110,9 +140,9 @@ def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', kernel_sc """ ) - print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=kernel_scale, format='torch', sparse=False, diagonal=False, quantize=kernel_quantize) - print_dense_layer(where, name + "_gain", w_gain, b_gain, scale=gain_scale, format='torch', sparse=False, diagonal=False, quantize=gain_quantize) - print_dense_layer(where, name + "_global_gain", w_global_gain, b_global_gain, scale=global_gain_scale, format='torch', sparse=False, diagonal=False, quantize=global_gain_quantize) + print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=scale, format='torch', sparse=False, diagonal=False, quantize=quantize) + print_dense_layer(where, name + "_gain", w_gain, b_gain, format='torch', sparse=False, diagonal=False, quantize=False) + print_dense_layer(where, name + "_global_gain", w_global_gain, b_global_gain, format='torch', sparse=False, diagonal=False, quantize=False) else: @@ -100,8 +100,12 @@ opus_int silk_Get_Decoder_Size( /* O Returns error co ); /*************************/ -/* Init or Reset decoder */ +/* Init and Reset decoder */ /*************************/ +opus_int silk_ResetDecoder( /* O Returns error code */ + void *decState /* I/O State */ +); + opus_int silk_InitDecoder( /* O Returns error code */ void *decState /* I/O State */ ); diff --git a/silk/control.h b/silk/control.h index d30d114c..f5633e62 100644 --- a/silk/control.h +++ b/silk/control.h @@ -147,6 +147,11 @@ typedef struct { /* I: Enable Deep PLC */ opus_int enable_deep_plc; + +#ifdef ENABLE_OSCE + /* I: OSCE method */ + opus_int osce_method; +#endif } silk_DecControlStruct; #ifdef __cplusplus diff --git a/silk/dec_API.c b/silk/dec_API.c index a29ecc73..81c433d8 100644 --- a/silk/dec_API.c +++ b/silk/dec_API.c @@ -33,6 +33,10 @@ POSSIBILITY OF SUCH DAMAGE. #include "stack_alloc.h" #include "os_support.h" +#ifdef ENABLE_OSCE +#include "osce.h" +#endif + /************************/ /* Decoder Super Struct */ /************************/ @@ -60,6 +64,24 @@ opus_int silk_Get_Decoder_Size( /* O Returns error co } /* Reset decoder state */ +opus_int silk_ResetDecoder( /* O Returns error code */ + void *decState /* I/O State */ +) +{ + opus_int n, ret = SILK_NO_ERROR; + silk_decoder_state *channel_state = ((silk_decoder *)decState)->channel_state; + + for( n = 0; n < DECODER_NUM_CHANNELS; n++ ) { + ret = silk_reset_decoder( &channel_state[ n ] ); + } + silk_memset(&((silk_decoder *)decState)->sStereo, 0, sizeof(((silk_decoder *)decState)->sStereo)); + /* Not strictly needed, but it's cleaner that way */ + ((silk_decoder *)decState)->prev_decode_only_middle = 0; + + return ret; +} + + opus_int silk_InitDecoder( /* O Returns error code */ void *decState /* I/O State */ ) @@ -301,6 +323,9 @@ opus_int silk_Decode( /* O Returns error co } else { condCoding = CODE_CONDITIONALLY; } +#ifdef ENABLE_OSCE + if (channel_state[n].osce.method != decControl->osce_method) {osce_reset(&channel_state[n].osce, decControl->osce_method);} +#endif ret += silk_decode_frame( &channel_state[ n ], psRangeDec, &samplesOut1_tmp[ n ][ 2 ], &nSamplesOutDec, lostFlag, condCoding, #ifdef ENABLE_DEEP_PLC n == 0 ? lpcnet : NULL, diff --git a/silk/decode_frame.c b/silk/decode_frame.c index 80283a38..d35aac08 100644 --- a/silk/decode_frame.c +++ b/silk/decode_frame.c @@ -140,6 +140,10 @@ opus_int silk_decode_frame( /********************************************************/ osce_enhance_frame( psDec, psDecCtrl, pOut, ec_tell(psRangeDec) - ec_start, arch ); } + else + { + osce_reset( &psDec->osce, psDec->osce.method ); + } #endif /************************************************/ diff --git a/silk/init_decoder.c b/silk/init_decoder.c index 8e873c60..2dbda429 100644 --- a/silk/init_decoder.c +++ b/silk/init_decoder.c @@ -35,15 +35,17 @@ POSSIBILITY OF SUCH DAMAGE. #include "osce.h" #endif +#include "struct.h" + /************************/ -/* Init Decoder State */ +/* Reset Decoder State */ /************************/ -opus_int silk_init_decoder( +opus_int silk_reset_decoder( silk_decoder_state *psDec /* I/O Decoder state pointer */ ) { /* Clear the entire encoder state, except anything copied */ - silk_memset( psDec, 0, sizeof( silk_decoder_state ) ); + silk_memset( &psDec->SILK_DECODER_STATE_RESET_START, 0, sizeof( silk_decoder_state ) - ((char*) &psDec->SILK_DECODER_STATE_RESET_START - (char*)psDec) ); /* Used to deactivate LSF interpolation */ psDec->first_frame_after_reset = 1; @@ -57,9 +59,30 @@ opus_int silk_init_decoder( silk_PLC_Reset( psDec ); #ifdef ENABLE_OSCE + /* Reset OSCE state and method */ + osce_reset(&psDec->osce, OSCE_DEFAULT_METHOD); +#endif + + return 0; +} + + +/************************/ +/* Init Decoder State */ +/************************/ +opus_int silk_init_decoder( + silk_decoder_state *psDec /* I/O Decoder state pointer */ +) +{ + /* Clear the entire encoder state, except anything copied */ + silk_memset( psDec, 0, sizeof( silk_decoder_state ) ); + +#ifdef ENABLE_OSCE osce_init(&psDec->osce, OSCE_DEFAULT_METHOD, NULL); #endif + silk_reset_decoder( psDec ); + return(0); } diff --git a/silk/main.h b/silk/main.h index c67775ef..d5cb2a6e 100644 --- a/silk/main.h +++ b/silk/main.h @@ -389,6 +389,10 @@ void silk_NLSF_decode( /****************************************************/ /* Decoder Functions */ /****************************************************/ +opus_int silk_reset_decoder( + silk_decoder_state *psDec /* I/O Decoder state pointer */ +); + opus_int silk_init_decoder( silk_decoder_state *psDec /* I/O Decoder state pointer */ ); diff --git a/silk/structs.h b/silk/structs.h index 5f912edf..123b384f 100644 --- a/silk/structs.h +++ b/silk/structs.h @@ -284,6 +284,10 @@ typedef struct { /* Decoder state */ /********************************/ typedef struct { +#ifdef ENABLE_OSCE + silk_OSCE_struct osce; +#endif +#define SILK_DECODER_STATE_RESET_START prev_gain_Q16 opus_int32 prev_gain_Q16; opus_int32 exc_Q14[ MAX_FRAME_LENGTH ]; opus_int32 sLPC_Q14_buf[ MAX_LPC_ORDER ]; @@ -324,10 +328,6 @@ typedef struct { /* CNG state */ silk_CNG_struct sCNG; -#ifdef ENABLE_OSCE - silk_OSCE_struct osce; -#endif - /* Stuff used for PLC */ opus_int lossCnt; opus_int prevSignalType; diff --git a/src/opus_decoder.c b/src/opus_decoder.c index 1e0a1da4..44a183d4 100644 --- a/src/opus_decoder.c +++ b/src/opus_decoder.c @@ -57,6 +57,10 @@ #include "dred_rdovae_dec.h" #endif +#ifdef ENABLE_OSCE +#include "osce.h" +#endif + struct OpusDecoder { int celt_dec_offset; int silk_dec_offset; @@ -383,7 +387,7 @@ static int opus_decode_frame(OpusDecoder *st, const unsigned char *data, pcm_ptr = pcm_silk; if (st->prev_mode==MODE_CELT_ONLY) - silk_InitDecoder( silk_dec ); + silk_ResetDecoder( silk_dec ); /* The SILK PLC cannot produce frames of less than 10 ms */ st->DecControl.payloadSize_ms = IMAX(10, 1000 * audiosize / st->Fs); @@ -408,6 +412,15 @@ static int opus_decode_frame(OpusDecoder *st, const unsigned char *data, } } st->DecControl.enable_deep_plc = st->complexity >= 5; +#ifdef ENABLE_OSCE + st->DecControl.osce_method = OSCE_METHOD_NONE; +#ifndef DISABLE_LACE + if (st->complexity >= 2) {st->DecControl.osce_method = OSCE_METHOD_LACE;} +#endif +#ifndef DISABLE_NOLACE + if (st->complexity >= 5) {st->DecControl.osce_method = OSCE_METHOD_NOLACE;} +#endif +#endif lost_flag = data == NULL ? 1 : 2 * !!decode_fec; decoded_samples = 0; @@ -953,7 +966,7 @@ int opus_decoder_ctl(OpusDecoder *st, int request, ...) ((char*)&st->OPUS_DECODER_RESET_START - (char*)st)); celt_decoder_ctl(celt_dec, OPUS_RESET_STATE); - silk_InitDecoder( silk_dec ); + silk_ResetDecoder( silk_dec ); st->stream_channels = st->channels; st->frame_size = st->Fs/400; #ifdef ENABLE_DEEP_PLC |