diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-06-07 00:19:12 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-06-16 20:02:26 +0300 |
commit | a8cb719d05b2a45d4cf12ca1a61755cc5d904e55 (patch) | |
tree | 61b07d55f6d33ff825c0714c2ca0d24c8f7247d7 /src | |
parent | 0dad5e06abb91adc4737160a30e51c43cd8742c2 (diff) |
Add blob loading for DRED encoder and decoder
Diffstat (limited to 'src')
-rw-r--r-- | src/opus_decoder.c | 50 | ||||
-rw-r--r-- | src/opus_demo.c | 8 | ||||
-rw-r--r-- | src/opus_encoder.c | 13 |
3 files changed, 68 insertions, 3 deletions
diff --git a/src/opus_decoder.c b/src/opus_decoder.c index d28a052a..aad378f0 100644 --- a/src/opus_decoder.c +++ b/src/opus_decoder.c @@ -1141,6 +1141,16 @@ int opus_dred_decoder_get_size(void) return sizeof(OpusDREDDecoder); } +int dred_decoder_load_model(OpusDREDDecoder *dec, const unsigned char *data, int len) +{ + WeightArray *list; + int ret; + parse_weights(&list, data, len); + ret = init_rdovaedec(&dec->model, list); + free(list); + return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG; +} + int opus_dred_decoder_init(OpusDREDDecoder *dec) { #ifndef USE_WEIGHTS_FILE @@ -1180,7 +1190,47 @@ void opus_dred_decoder_destroy(OpusDREDDecoder *dec) free(dec); } +int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...) +{ +#ifdef ENABLE_NEURAL_FEC + int ret = OPUS_OK; + va_list ap; + va_start(ap, request); + (void)dred_dec; + switch (request) + { +# ifdef USE_WEIGHTS_FILE + case OPUS_SET_DNN_BLOB_REQUEST: + { + const unsigned char *data = va_arg(ap, const unsigned char *); + opus_int32 len = va_arg(ap, opus_int32); + if(len<0 || data == NULL) + { + goto bad_arg; + } + return dred_decoder_load_model(dred_dec, data, len); + } + break; +# endif + default: + /*fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);*/ + ret = OPUS_UNIMPLEMENTED; + break; + } + va_end(ap); + return ret; +# ifdef USE_WEIGHTS_FILE +bad_arg: + va_end(ap); + return OPUS_BAD_ARG; +# endif +#else + (void)dred_dec; + (void)request; + return OPUS_UNIMPLEMENTED; +#endif +} #ifdef ENABLE_NEURAL_FEC static int dred_find_payload(const unsigned char *data, opus_int32 len, const unsigned char **payload) diff --git a/src/opus_demo.c b/src/opus_demo.c index 563d3c5b..b48845c9 100644 --- a/src/opus_demo.c +++ b/src/opus_demo.c @@ -617,9 +617,6 @@ int main(int argc, char *argv[]) goto failure; } } -#ifdef USE_WEIGHTS_FILE - opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len)); -#endif switch(bandwidth) { case OPUS_BANDWIDTH_NARROWBAND: @@ -684,6 +681,11 @@ int main(int argc, char *argv[]) } dred_dec = opus_dred_decoder_create(&err); dred = opus_dred_alloc(&err); +#ifdef USE_WEIGHTS_FILE + opus_encoder_ctl(enc, OPUS_SET_DNN_BLOB(blob_data, blob_len)); + opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len)); + opus_dred_decoder_ctl(dred_dec, OPUS_SET_DNN_BLOB(blob_data, blob_len)); +#endif while (!stop) { if (delayed_celt) diff --git a/src/opus_encoder.c b/src/opus_encoder.c index 3738c848..f6d3bc58 100644 --- a/src/opus_encoder.c +++ b/src/opus_encoder.c @@ -2847,6 +2847,19 @@ int opus_encoder_ctl(OpusEncoder *st, int request, ...) } } break; +#ifdef USE_WEIGHTS_FILE + case OPUS_SET_DNN_BLOB_REQUEST: + { + const unsigned char *data = va_arg(ap, const unsigned char *); + opus_int32 len = va_arg(ap, opus_int32); + if(len<0 || data == NULL) + { + goto bad_arg; + } + return dred_encoder_load_model(&st->dred_encoder, data, len); + } + break; +#endif case CELT_GET_MODE_REQUEST: { const CELTMode ** value = va_arg(ap, const CELTMode**); |