diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-06-07 00:19:12 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-06-07 00:19:12 +0300 |
commit | aab8a6f82c0b95fd7a09918656dabdf08f2c1044 (patch) | |
tree | e06806cba1a4286a441cc9dcb90484eeea4d3ac6 | |
parent | 3170791311546cd9b38dd65d1595f5630ea50c8e (diff) |
Add blob loading for DRED encoder and decoderexp_dred_refactor1
-rw-r--r-- | include/opus.h | 13 | ||||
-rw-r--r-- | silk/dred_encoder.c | 10 | ||||
-rw-r--r-- | silk/dred_encoder.h | 2 | ||||
-rw-r--r-- | src/opus_decoder.c | 42 | ||||
-rw-r--r-- | src/opus_demo.c | 8 | ||||
-rw-r--r-- | src/opus_encoder.c | 13 |
6 files changed, 83 insertions, 5 deletions
diff --git a/include/opus.h b/include/opus.h index 41244f23..a52daa22 100644 --- a/include/opus.h +++ b/include/opus.h @@ -547,7 +547,18 @@ OPUS_EXPORT int opus_dred_decoder_init(OpusDREDDecoder *dec); */ OPUS_EXPORT void opus_dred_decoder_destroy(OpusDREDDecoder *dec); - +/** Perform a CTL function on an Opus DRED decoder. + * + * Generally the request and subsequent arguments are generated + * by a convenience macro. + * @param st <tt>OpusDREDDecoder*</tt>: DRED Decoder state. + * @param request This and all remaining parameters should be replaced by one + * of the convenience macros in @ref opus_genericctls or + * @ref opus_decoderctls. + * @see opus_genericctls + * @see opus_decoderctls + */ +OPUS_EXPORT int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...); /** Gets the size of an <code>OpusDRED</code> structure. * @returns The size in bytes. diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c index afec129d..c2628c7f 100644 --- a/silk/dred_encoder.c +++ b/silk/dred_encoder.c @@ -44,6 +44,16 @@ #include "float_cast.h" #include "os_support.h" +int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len) +{ + WeightArray *list; + int ret; + parse_weights(&list, data, len); + ret = init_rdovaeenc(&enc->model, list); + free(list); + return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG; +} + void dred_encoder_reset(DREDEnc* enc) { RNN_CLEAR((char*)&enc->DREDENC_RESET_START, diff --git a/silk/dred_encoder.h b/silk/dred_encoder.h index 30e639a9..439ef654 100644 --- a/silk/dred_encoder.h +++ b/silk/dred_encoder.h @@ -54,7 +54,7 @@ typedef struct { RDOVAEEncState rdovae_enc; } DREDEnc; - +int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len); void dred_encoder_init(DREDEnc* enc, opus_int32 Fs, int channels); void dred_encoder_reset(DREDEnc* enc); diff --git a/src/opus_decoder.c b/src/opus_decoder.c index d28a052a..51b78972 100644 --- a/src/opus_decoder.c +++ b/src/opus_decoder.c @@ -1141,6 +1141,16 @@ int opus_dred_decoder_get_size(void) return sizeof(OpusDREDDecoder); } +int dred_decoder_load_model(OpusDREDDecoder *dec, const unsigned char *data, int len) +{ + WeightArray *list; + int ret; + parse_weights(&list, data, len); + ret = init_rdovaedec(&dec->model, list); + free(list); + return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG; +} + int opus_dred_decoder_init(OpusDREDDecoder *dec) { #ifndef USE_WEIGHTS_FILE @@ -1180,7 +1190,39 @@ void opus_dred_decoder_destroy(OpusDREDDecoder *dec) free(dec); } +int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...) +{ + int ret = OPUS_OK; + va_list ap; + + va_start(ap, request); + switch (request) + { +#ifdef USE_WEIGHTS_FILE + case OPUS_SET_DNN_BLOB_REQUEST: + { + const unsigned char *data = va_arg(ap, const unsigned char *); + opus_int32 len = va_arg(ap, opus_int32); + if(len<0 || data == NULL) + { + goto bad_arg; + } + return dred_decoder_load_model(dred_dec, data, len); + } + break; +#endif + default: + /*fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);*/ + ret = OPUS_UNIMPLEMENTED; + break; + } + va_end(ap); + return ret; +bad_arg: + va_end(ap); + return OPUS_BAD_ARG; +} #ifdef ENABLE_NEURAL_FEC static int dred_find_payload(const unsigned char *data, opus_int32 len, const unsigned char **payload) diff --git a/src/opus_demo.c b/src/opus_demo.c index 563d3c5b..b48845c9 100644 --- a/src/opus_demo.c +++ b/src/opus_demo.c @@ -617,9 +617,6 @@ int main(int argc, char *argv[]) goto failure; } } -#ifdef USE_WEIGHTS_FILE - opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len)); -#endif switch(bandwidth) { case OPUS_BANDWIDTH_NARROWBAND: @@ -684,6 +681,11 @@ int main(int argc, char *argv[]) } dred_dec = opus_dred_decoder_create(&err); dred = opus_dred_alloc(&err); +#ifdef USE_WEIGHTS_FILE + opus_encoder_ctl(enc, OPUS_SET_DNN_BLOB(blob_data, blob_len)); + opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len)); + opus_dred_decoder_ctl(dred_dec, OPUS_SET_DNN_BLOB(blob_data, blob_len)); +#endif while (!stop) { if (delayed_celt) diff --git a/src/opus_encoder.c b/src/opus_encoder.c index 3738c848..f6d3bc58 100644 --- a/src/opus_encoder.c +++ b/src/opus_encoder.c @@ -2847,6 +2847,19 @@ int opus_encoder_ctl(OpusEncoder *st, int request, ...) } } break; +#ifdef USE_WEIGHTS_FILE + case OPUS_SET_DNN_BLOB_REQUEST: + { + const unsigned char *data = va_arg(ap, const unsigned char *); + opus_int32 len = va_arg(ap, opus_int32); + if(len<0 || data == NULL) + { + goto bad_arg; + } + return dred_encoder_load_model(&st->dred_encoder, data, len); + } + break; +#endif case CELT_GET_MODE_REQUEST: { const CELTMode ** value = va_arg(ap, const CELTMode**); |