From 35d82b320a4ea3c78bf68b9901eb0f019bf7bb8b Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin Date: Fri, 4 Aug 2023 16:17:24 -0400 Subject: Use FWGAN instead of LPCNet in PLC --- dnn/lpcnet_plc.c | 102 ++++++--------------------------------------------- dnn/lpcnet_private.h | 7 ++-- 2 files changed, 15 insertions(+), 94 deletions(-) diff --git a/dnn/lpcnet_plc.c b/dnn/lpcnet_plc.c index 879064af..31bbcdc5 100644 --- a/dnn/lpcnet_plc.c +++ b/dnn/lpcnet_plc.c @@ -49,18 +49,15 @@ void lpcnet_plc_reset(LPCNetPLCState *st) { OPUS_CLEAR((char*)&st->LPCNET_PLC_RESET_START, sizeof(LPCNetPLCState)- ((char*)&st->LPCNET_PLC_RESET_START - (char*)st)); - lpcnet_reset(&st->lpcnet); lpcnet_encoder_init(&st->enc); OPUS_CLEAR(st->pcm, PLC_BUF_SIZE); - st->pcm_fill = PLC_BUF_SIZE; - st->skip_analysis = 0; st->blend = 0; st->loss_count = 0; } int lpcnet_plc_init(LPCNetPLCState *st, int options) { int ret; - lpcnet_init(&st->lpcnet); + fwgan_init(&st->fwgan); lpcnet_encoder_init(&st->enc); if ((options&0x3) == LPCNET_PLC_CAUSAL) { st->enable_blending = 1; @@ -86,7 +83,7 @@ int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len ret = init_plc_model(&st->model, list); free(list); if (ret == 0) { - return lpcnet_load_model(&st->lpcnet, data, len); + return fwgan_load_model(&st->fwgan, data, len); } else return -1; } @@ -166,65 +163,18 @@ static void fec_rewind(LPCNetPLCState *st, int offset) { } } -void clear_state(LPCNetPLCState *st) { - OPUS_CLEAR(st->lpcnet.last_sig, LPC_ORDER); - st->lpcnet.last_exc = lin2ulaw(0.f); - st->lpcnet.deemph_mem = 0; - OPUS_CLEAR(st->lpcnet.nnet.gru_a_state, GRU_A_STATE_SIZE); - OPUS_CLEAR(st->lpcnet.nnet.gru_b_state, GRU_B_STATE_SIZE); -} - /* In this causal version of the code, the DNN model implemented by compute_plc_pred() needs to generate two feature vectors to conceal the first lost packet.*/ int lpcnet_plc_update(LPCNetPLCState *st, opus_int16 *pcm) { int i; float x[FRAME_SIZE]; - opus_int16 output[FRAME_SIZE]; float plc_features[2*NB_BANDS+NB_FEATURES+1]; - int delta = 0; for (i=0;iskip_analysis) { - /*fprintf(stderr, "skip update\n");*/ - if (st->blend) { - opus_int16 tmp[FRAME_SIZE-TRAINING_OFFSET]; - float zeros[2*NB_BANDS+NB_FEATURES+1] = {0}; - OPUS_COPY(zeros, plc_features, 2*NB_BANDS); - zeros[2*NB_BANDS+NB_FEATURES] = 1; - if (st->enable_blending) { - LPCNetState copy; - st->plc_net = st->plc_copy[FEATURES_DELAY]; - compute_plc_pred(st, st->features, zeros); - for (i=0;ilpcnet, st->features); - } - copy = st->lpcnet; - lpcnet_synthesize_impl(&st->lpcnet, &st->features[0], tmp, FRAME_SIZE-TRAINING_OFFSET, 0); - for (i=0;ilpcnet = copy; - /*lpcnet_synthesize_impl(&st->lpcnet, &st->features[0], pcm, FRAME_SIZE-TRAINING_OFFSET, FRAME_SIZE-TRAINING_OFFSET);*/ - } else { - if (FEATURES_DELAY > 0) st->plc_net = st->plc_copy[FEATURES_DELAY-1]; - fec_rewind(st, FEATURES_DELAY); -#ifdef PLC_SKIP_UPDATES - lpcnet_reset_signal(&st->lpcnet); -#else - OPUS_COPY(tmp, pcm, FRAME_SIZE-TRAINING_OFFSET); - lpcnet_synthesize_tail_impl(&st->lpcnet, tmp, FRAME_SIZE-TRAINING_OFFSET, FRAME_SIZE-TRAINING_OFFSET); -#endif - } - OPUS_COPY(st->pcm, pcm, FRAME_SIZE); - st->pcm_fill = FRAME_SIZE; - } else { - OPUS_COPY(&st->pcm[st->pcm_fill], pcm, FRAME_SIZE); - st->pcm_fill += FRAME_SIZE; - } + if (st->blend) { + if (FEATURES_DELAY > 0) st->plc_net = st->plc_copy[FEATURES_DELAY-1]; + fec_rewind(st, FEATURES_DELAY); } /* Update state. */ /*fprintf(stderr, "update state\n");*/ @@ -241,24 +191,8 @@ int lpcnet_plc_update(LPCNetPLCState *st, opus_int16 *pcm) { else if (st->fec_read_pos < st->fec_fill_pos) st->fec_read_pos++; st->fec_keep_pos = IMAX(0, IMAX(st->fec_keep_pos, st->fec_read_pos-FEATURES_DELAY-1)); } - if (st->skip_analysis) { - if (st->enable_blending) { - /* FIXME: backtrack state, replace features. */ - run_frame_network_deferred(&st->lpcnet, st->enc.features); - } - st->skip_analysis--; - } else { - for (i=0;ipcm[PLC_BUF_SIZE+i] = pcm[i]; - OPUS_COPY(output, &st->pcm[0], FRAME_SIZE); -#ifdef PLC_SKIP_UPDATES - { - run_frame_network_deferred(&st->lpcnet, st->enc.features); - } -#else - lpcnet_synthesize_impl(&st->lpcnet, st->enc.features, output, FRAME_SIZE, FRAME_SIZE); -#endif - OPUS_MOVE(st->pcm, &st->pcm[FRAME_SIZE], PLC_BUF_SIZE); - } + OPUS_MOVE(st->pcm, &st->pcm[FRAME_SIZE], FWGAN_CONT_SAMPLES-FRAME_SIZE); + for (i=0;ipcm[FWGAN_CONT_SAMPLES-FRAME_SIZE+i] = (1.f/32768.f)*pcm[i]; st->loss_count = 0; st->blend = 0; return 0; @@ -267,31 +201,17 @@ int lpcnet_plc_update(LPCNetPLCState *st, opus_int16 *pcm) { static const float att_table[10] = {0, 0, -.2, -.2, -.4, -.4, -.8, -.8, -1.6, -1.6}; int lpcnet_plc_conceal(LPCNetPLCState *st, opus_int16 *pcm) { int i; - opus_int16 output[FRAME_SIZE]; - run_frame_network_flush(&st->lpcnet); - /* If we concealed the previous frame, finish synthesizing the rest of the samples. */ - /* FIXME: Copy/predict features. */ - while (st->pcm_fill > 0) { - /*fprintf(stderr, "update state for PLC %d\n", st->pcm_fill);*/ - int update_count; - update_count = IMIN(st->pcm_fill, FRAME_SIZE); - OPUS_COPY(output, &st->pcm[0], update_count); - OPUS_MOVE(&st->plc_copy[1], &st->plc_copy[0], FEATURES_DELAY); - st->plc_copy[0] = st->plc_net; + if (st->blend == 0) { get_fec_or_pred(st, st->features); - lpcnet_synthesize_impl(&st->lpcnet, &st->features[0], output, update_count, update_count); - OPUS_MOVE(st->pcm, &st->pcm[FRAME_SIZE], PLC_BUF_SIZE); - st->pcm_fill -= update_count; - st->skip_analysis++; + fwgan_cont(&st->fwgan, st->pcm, &st->features[0]); } OPUS_MOVE(&st->plc_copy[1], &st->plc_copy[0], FEATURES_DELAY); st->plc_copy[0] = st->plc_net; - /*lpcnet_synthesize_tail_impl(&st->lpcnet, pcm, FRAME_SIZE-TRAINING_OFFSET, 0);*/ if (get_fec_or_pred(st, st->features)) st->loss_count = 0; else st->loss_count++; if (st->loss_count >= 10) st->features[0] = MAX16(-10, st->features[0]+att_table[9] - 2*(st->loss_count-9)); else st->features[0] = MAX16(-10, st->features[0]+att_table[st->loss_count]); - lpcnet_synthesize_impl(&st->lpcnet, &st->features[0], pcm, FRAME_SIZE, 0); + fwgan_synthesize_int(&st->fwgan, pcm, &st->features[0]); { float x[FRAME_SIZE]; /* FIXME: Can we do better? */ @@ -300,6 +220,8 @@ int lpcnet_plc_conceal(LPCNetPLCState *st, opus_int16 *pcm) { compute_frame_features(&st->enc, x); process_single_frame(&st->enc, NULL); } + OPUS_MOVE(st->pcm, &st->pcm[FRAME_SIZE], FWGAN_CONT_SAMPLES-FRAME_SIZE); + for (i=0;ipcm[FWGAN_CONT_SAMPLES-FRAME_SIZE+i] = (1.f/32768.f)*pcm[i]; st->blend = 1; return 0; } diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h index da048e7e..7d95c02e 100644 --- a/dnn/lpcnet_private.h +++ b/dnn/lpcnet_private.h @@ -7,6 +7,7 @@ #include "nnet_data.h" #include "plc_data.h" #include "kiss99.h" +#include "fwgan.h" #define PITCH_MIN_PERIOD 32 #define PITCH_MAX_PERIOD 256 @@ -65,7 +66,7 @@ struct LPCNetEncState{ #define PLC_BUF_SIZE (FEATURES_DELAY*FRAME_SIZE + FRAME_SIZE) struct LPCNetPLCState { PLCModel model; - LPCNetState lpcnet; + FWGANState fwgan; LPCNetEncState enc; int arch; int enable_blending; @@ -76,9 +77,7 @@ struct LPCNetPLCState { int fec_read_pos; int fec_fill_pos; int fec_skip; - opus_int16 pcm[PLC_BUF_SIZE+FRAME_SIZE]; - int pcm_fill; - int skip_analysis; + float pcm[FWGAN_CONT_SAMPLES]; int blend; float features[NB_TOTAL_FEATURES]; int loss_count; -- cgit v1.2.3