Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-10-30 21:08:07 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-10-30 21:08:07 +0300
commitfeb32828877ea5e8723ea2a446eb20d7b3fba426 (patch)
treec008f01b3c91bd5bc82caf0f93ff5cdef7283f24
parent62b546436fc07035802eb998f61702ee2716db60 (diff)
Don't try to use models that aren't loaded
-rw-r--r--celt/celt_decoder.c4
-rw-r--r--dnn/lpcnet_plc.c11
-rw-r--r--dnn/lpcnet_private.h1
-rw-r--r--silk/PLC.c2
-rw-r--r--silk/dred_encoder.c6
-rw-r--r--silk/dred_encoder.h5
-rw-r--r--src/opus_decoder.c13
-rw-r--r--src/opus_encoder.c26
8 files changed, 43 insertions, 25 deletions
diff --git a/celt/celt_decoder.c b/celt/celt_decoder.c
index c0c997f0..ac1b4fed 100644
--- a/celt/celt_decoder.c
+++ b/celt/celt_decoder.c
@@ -721,7 +721,7 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM
if (loss_duration == 0)
{
#ifdef ENABLE_DEEP_PLC
- update_plc_state(lpcnet, decode_mem, &st->plc_preemphasis_mem, C);
+ if (lpcnet->loaded) update_plc_state(lpcnet, decode_mem, &st->plc_preemphasis_mem, C);
#endif
st->last_pitch_index = pitch_index = celt_plc_pitch_search(decode_mem, C, st->arch);
} else {
@@ -914,7 +914,7 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM
} while (++c<C);
#ifdef ENABLE_DEEP_PLC
- if (st->complexity >= 5 || lpcnet->fec_fill_pos > 0) {
+ if (lpcnet->loaded && (st->complexity >= 5 || lpcnet->fec_fill_pos > 0)) {
float overlap_mem;
int samples_needed16k;
celt_sig *buf;
diff --git a/dnn/lpcnet_plc.c b/dnn/lpcnet_plc.c
index a124cdd0..de3ab1a7 100644
--- a/dnn/lpcnet_plc.c
+++ b/dnn/lpcnet_plc.c
@@ -57,8 +57,10 @@ int lpcnet_plc_init(LPCNetPLCState *st) {
fargan_init(&st->fargan);
lpcnet_encoder_init(&st->enc);
st->analysis_pos = PLC_BUF_SIZE;
+ st->loaded = 0;
#ifndef USE_WEIGHTS_FILE
ret = init_plc_model(&st->model, lpcnet_plc_arrays);
+ if (ret == 0) st->loaded = 1;
#else
ret = 0;
#endif
@@ -75,11 +77,12 @@ int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len
free(list);
if (ret == 0) {
ret = lpcnet_encoder_load_model(&st->enc, data, len);
- } else return -1;
+ }
if (ret == 0) {
- return fargan_load_model(&st->fargan, data, len);
+ ret = fargan_load_model(&st->fargan, data, len);
}
- else return -1;
+ if (ret == 0) st->loaded = 1;
+ return ret;
}
void lpcnet_plc_fec_add(LPCNetPLCState *st, const float *features) {
@@ -105,6 +108,7 @@ static void compute_plc_pred(LPCNetPLCState *st, float *out, const float *in) {
float zeros[3*PLC_MAX_RNN_NEURONS] = {0};
float dense_out[PLC_DENSE1_OUT_SIZE];
PLCNetState *net = &st->plc_net;
+ celt_assert(st->loaded);
_lpcnet_compute_dense(&st->model.plc_dense1, dense_out, in);
compute_gruB(&st->model.plc_gru1, zeros, net->plc_gru1_state, dense_out);
compute_gruB(&st->model.plc_gru2, zeros, net->plc_gru2_state, net->plc_gru1_state);
@@ -152,6 +156,7 @@ int lpcnet_plc_update(LPCNetPLCState *st, opus_int16 *pcm) {
static const float att_table[10] = {0, 0, -.2, -.2, -.4, -.4, -.8, -.8, -1.6, -1.6};
int lpcnet_plc_conceal(LPCNetPLCState *st, opus_int16 *pcm) {
int i;
+ celt_assert(st->loaded);
if (st->blend == 0) {
int count = 0;
while (st->analysis_pos + FRAME_SIZE <= PLC_BUF_SIZE) {
diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h
index 4f328ad2..30931b1d 100644
--- a/dnn/lpcnet_private.h
+++ b/dnn/lpcnet_private.h
@@ -47,6 +47,7 @@ struct LPCNetPLCState {
PLCModel model;
FARGANState fargan;
LPCNetEncState enc;
+ int loaded;
int arch;
#define LPCNET_PLC_RESET_START fec
diff --git a/silk/PLC.c b/silk/PLC.c
index 1e524823..b35bf750 100644
--- a/silk/PLC.c
+++ b/silk/PLC.c
@@ -397,7 +397,7 @@ static OPUS_INLINE void silk_PLC_conceal(
frame[ i ] = (opus_int16)silk_SAT16( silk_SAT16( silk_RSHIFT_ROUND( silk_SMULWW( sLPC_Q14_ptr[ MAX_LPC_ORDER + i ], prevGain_Q10[ 1 ] ), 8 ) ) );
}
#ifdef ENABLE_DEEP_PLC
- if ( lpcnet != NULL && psDec->sPLC.fs_kHz == 16 ) {
+ if ( lpcnet != NULL && lpcnet->loaded && psDec->sPLC.fs_kHz == 16 ) {
int run_deep_plc = psDec->sPLC.enable_deep_plc || lpcnet->fec_fill_pos != 0;
if( run_deep_plc ) {
for( k = 0; k < psDec->nb_subfr; k += 2 ) {
diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c
index af7f9d94..9b005a63 100644
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -57,6 +57,7 @@ int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len)
if (ret == 0) {
ret = lpcnet_encoder_load_model(&enc->lpcnet_enc_state, data, len);
}
+ if (ret == 0) enc->loaded = 1;
return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
}
@@ -74,8 +75,9 @@ void dred_encoder_init(DREDEnc* enc, opus_int32 Fs, int channels)
{
enc->Fs = Fs;
enc->channels = channels;
+ enc->loaded = 0;
#ifndef USE_WEIGHTS_FILE
- init_rdovaeenc(&enc->model, rdovaeenc_arrays);
+ if (init_rdovaeenc(&enc->model, rdovaeenc_arrays) == 0) enc->loaded = 1;
#endif
dred_encoder_reset(enc);
}
@@ -85,6 +87,7 @@ static void dred_process_frame(DREDEnc *enc)
float feature_buffer[2 * 36];
float input_buffer[2*DRED_NUM_FEATURES] = {0};
+ celt_assert(enc->loaded);
/* shift latents buffer */
OPUS_MOVE(enc->latents_buffer + DRED_LATENT_DIM, enc->latents_buffer, (DRED_MAX_FRAMES - 1) * DRED_LATENT_DIM);
@@ -184,6 +187,7 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex
{
int curr_offset16k;
int frame_size16k = frame_size * 16000 / enc->Fs;
+ celt_assert(enc->loaded);
curr_offset16k = 40 + extra_delay*16000/enc->Fs - enc->input_buffer_fill;
enc->dred_offset = (int)floor((curr_offset16k+20.f)/40.f);
enc->latent_offset = 0;
diff --git a/silk/dred_encoder.h b/silk/dred_encoder.h
index 2b77d581..abeaac7f 100644
--- a/silk/dred_encoder.h
+++ b/silk/dred_encoder.h
@@ -40,6 +40,9 @@
typedef struct {
RDOVAEEnc model;
+ LPCNetEncState lpcnet_enc_state;
+ RDOVAEEncState rdovae_enc;
+ int loaded;
opus_int32 Fs;
int channels;
@@ -53,8 +56,6 @@ typedef struct {
float state_buffer[DRED_STATE_DIM];
float initial_state[DRED_STATE_DIM];
float resample_mem[RESAMPLING_ORDER + 1];
- LPCNetEncState lpcnet_enc_state;
- RDOVAEEncState rdovae_enc;
} DREDEnc;
int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len);
diff --git a/src/opus_decoder.c b/src/opus_decoder.c
index 67b1cfd3..999c6fe0 100644
--- a/src/opus_decoder.c
+++ b/src/opus_decoder.c
@@ -1042,7 +1042,7 @@ int opus_decoder_ctl(OpusDecoder *st, int request, ...)
{
goto bad_arg;
}
- return lpcnet_plc_load_model(&st->lpcnet, data, len);
+ ret = lpcnet_plc_load_model(&st->lpcnet, data, len);
}
break;
#endif
@@ -1156,6 +1156,7 @@ struct OpusDREDDecoder {
#ifdef ENABLE_DRED
RDOVAEDec model;
#endif
+ int loaded;
int arch;
opus_uint32 magic;
};
@@ -1188,19 +1189,23 @@ int dred_decoder_load_model(OpusDREDDecoder *dec, const unsigned char *data, int
parse_weights(&list, data, len);
ret = init_rdovaedec(&dec->model, list);
free(list);
+ if (ret == 0) dec->loaded = 1;
return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
}
#endif
int opus_dred_decoder_init(OpusDREDDecoder *dec)
{
+ int ret = 0;
+ dec->loaded = 0;
#if defined(ENABLE_DRED) && !defined(USE_WEIGHTS_FILE)
- init_rdovaedec(&dec->model, rdovaedec_arrays);
+ ret = init_rdovaedec(&dec->model, rdovaedec_arrays);
+ if (ret == 0) dec->loaded = 1;
#endif
dec->arch = opus_select_arch();
/* To make sure nobody forgets to init, use a magic number. */
dec->magic = 0xD8EDDEC0;
- return OPUS_OK;
+ return (ret == 0) ? OPUS_OK : OPUS_UNIMPLEMENTED;
}
OpusDREDDecoder *opus_dred_decoder_create(int *error)
@@ -1378,6 +1383,7 @@ int opus_dred_parse(OpusDREDDecoder *dred_dec, OpusDRED *dred, const unsigned ch
const unsigned char *payload;
opus_int32 payload_len;
VALIDATE_DRED_DECODER(dred_dec);
+ if (!dred_dec->loaded) return OPUS_UNIMPLEMENTED;
dred->process_stage = -1;
payload_len = dred_find_payload(data, len, &payload);
if (payload_len < 0)
@@ -1412,6 +1418,7 @@ int opus_dred_process(OpusDREDDecoder *dred_dec, const OpusDRED *src, OpusDRED *
if (dred_dec == NULL || src == NULL || dst == NULL || (src->process_stage != 1 && src->process_stage != 2))
return OPUS_BAD_ARG;
VALIDATE_DRED_DECODER(dred_dec);
+ if (!dred_dec->loaded) return OPUS_UNIMPLEMENTED;
if (src != dst)
OPUS_COPY(dst, src, 1);
if (dst->process_stage == 2)
diff --git a/src/opus_encoder.c b/src/opus_encoder.c
index 5ed4b187..27b3196a 100644
--- a/src/opus_encoder.c
+++ b/src/opus_encoder.c
@@ -1713,7 +1713,7 @@ opus_int32 opus_encode_native(OpusEncoder *st, const opus_val16 *pcm, int frame_
#endif
#ifdef ENABLE_DRED
- if ( st->dred_duration > 0 ) {
+ if ( st->dred_duration > 0 && st->dred_encoder.loaded ) {
/* DRED Encoder */
dred_compute_latents( &st->dred_encoder, &pcm_buf[total_buffer*st->channels], frame_size, total_buffer );
} else {
@@ -2255,7 +2255,7 @@ opus_int32 opus_encode_native(OpusEncoder *st, const opus_val16 *pcm, int frame_
ret += 1+redundancy_bytes;
apply_padding = !st->use_vbr;
#ifdef ENABLE_DRED
- if (st->dred_duration > 0) {
+ if (st->dred_duration > 0 && st->dred_encoder.loaded) {
opus_extension_data extension;
unsigned char buf[DRED_MAX_DATA_SIZE];
int dred_chunks;
@@ -2893,17 +2893,17 @@ int opus_encoder_ctl(OpusEncoder *st, int request, ...)
}
break;
#ifdef USE_WEIGHTS_FILE
- case OPUS_SET_DNN_BLOB_REQUEST:
- {
- const unsigned char *data = va_arg(ap, const unsigned char *);
- opus_int32 len = va_arg(ap, opus_int32);
- if(len<0 || data == NULL)
- {
- goto bad_arg;
- }
- return dred_encoder_load_model(&st->dred_encoder, data, len);
- }
- break;
+ case OPUS_SET_DNN_BLOB_REQUEST:
+ {
+ const unsigned char *data = va_arg(ap, const unsigned char *);
+ opus_int32 len = va_arg(ap, opus_int32);
+ if(len<0 || data == NULL)
+ {
+ goto bad_arg;
+ }
+ ret = dred_encoder_load_model(&st->dred_encoder, data, len);
+ }
+ break;
#endif
case CELT_GET_MODE_REQUEST:
{