Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-06-07 00:19:12 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-06-07 00:19:12 +0300
commitaab8a6f82c0b95fd7a09918656dabdf08f2c1044 (patch)
treee06806cba1a4286a441cc9dcb90484eeea4d3ac6
parent3170791311546cd9b38dd65d1595f5630ea50c8e (diff)
Add blob loading for DRED encoder and decoderexp_dred_refactor1
-rw-r--r--include/opus.h13
-rw-r--r--silk/dred_encoder.c10
-rw-r--r--silk/dred_encoder.h2
-rw-r--r--src/opus_decoder.c42
-rw-r--r--src/opus_demo.c8
-rw-r--r--src/opus_encoder.c13
6 files changed, 83 insertions, 5 deletions
diff --git a/include/opus.h b/include/opus.h
index 41244f23..a52daa22 100644
--- a/include/opus.h
+++ b/include/opus.h
@@ -547,7 +547,18 @@ OPUS_EXPORT int opus_dred_decoder_init(OpusDREDDecoder *dec);
*/
OPUS_EXPORT void opus_dred_decoder_destroy(OpusDREDDecoder *dec);
-
+/** Perform a CTL function on an Opus DRED decoder.
+ *
+ * Generally the request and subsequent arguments are generated
+ * by a convenience macro.
+ * @param st <tt>OpusDREDDecoder*</tt>: DRED Decoder state.
+ * @param request This and all remaining parameters should be replaced by one
+ * of the convenience macros in @ref opus_genericctls or
+ * @ref opus_decoderctls.
+ * @see opus_genericctls
+ * @see opus_decoderctls
+ */
+OPUS_EXPORT int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...);
/** Gets the size of an <code>OpusDRED</code> structure.
* @returns The size in bytes.
diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c
index afec129d..c2628c7f 100644
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -44,6 +44,16 @@
#include "float_cast.h"
#include "os_support.h"
+int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len)
+{
+ WeightArray *list;
+ int ret;
+ parse_weights(&list, data, len);
+ ret = init_rdovaeenc(&enc->model, list);
+ free(list);
+ return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
+}
+
void dred_encoder_reset(DREDEnc* enc)
{
RNN_CLEAR((char*)&enc->DREDENC_RESET_START,
diff --git a/silk/dred_encoder.h b/silk/dred_encoder.h
index 30e639a9..439ef654 100644
--- a/silk/dred_encoder.h
+++ b/silk/dred_encoder.h
@@ -54,7 +54,7 @@ typedef struct {
RDOVAEEncState rdovae_enc;
} DREDEnc;
-
+int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len);
void dred_encoder_init(DREDEnc* enc, opus_int32 Fs, int channels);
void dred_encoder_reset(DREDEnc* enc);
diff --git a/src/opus_decoder.c b/src/opus_decoder.c
index d28a052a..51b78972 100644
--- a/src/opus_decoder.c
+++ b/src/opus_decoder.c
@@ -1141,6 +1141,16 @@ int opus_dred_decoder_get_size(void)
return sizeof(OpusDREDDecoder);
}
+int dred_decoder_load_model(OpusDREDDecoder *dec, const unsigned char *data, int len)
+{
+ WeightArray *list;
+ int ret;
+ parse_weights(&list, data, len);
+ ret = init_rdovaedec(&dec->model, list);
+ free(list);
+ return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
+}
+
int opus_dred_decoder_init(OpusDREDDecoder *dec)
{
#ifndef USE_WEIGHTS_FILE
@@ -1180,7 +1190,39 @@ void opus_dred_decoder_destroy(OpusDREDDecoder *dec)
free(dec);
}
+int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...)
+{
+ int ret = OPUS_OK;
+ va_list ap;
+
+ va_start(ap, request);
+ switch (request)
+ {
+#ifdef USE_WEIGHTS_FILE
+ case OPUS_SET_DNN_BLOB_REQUEST:
+ {
+ const unsigned char *data = va_arg(ap, const unsigned char *);
+ opus_int32 len = va_arg(ap, opus_int32);
+ if(len<0 || data == NULL)
+ {
+ goto bad_arg;
+ }
+ return dred_decoder_load_model(dred_dec, data, len);
+ }
+ break;
+#endif
+ default:
+ /*fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);*/
+ ret = OPUS_UNIMPLEMENTED;
+ break;
+ }
+ va_end(ap);
+ return ret;
+bad_arg:
+ va_end(ap);
+ return OPUS_BAD_ARG;
+}
#ifdef ENABLE_NEURAL_FEC
static int dred_find_payload(const unsigned char *data, opus_int32 len, const unsigned char **payload)
diff --git a/src/opus_demo.c b/src/opus_demo.c
index 563d3c5b..b48845c9 100644
--- a/src/opus_demo.c
+++ b/src/opus_demo.c
@@ -617,9 +617,6 @@ int main(int argc, char *argv[])
goto failure;
}
}
-#ifdef USE_WEIGHTS_FILE
- opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
-#endif
switch(bandwidth)
{
case OPUS_BANDWIDTH_NARROWBAND:
@@ -684,6 +681,11 @@ int main(int argc, char *argv[])
}
dred_dec = opus_dred_decoder_create(&err);
dred = opus_dred_alloc(&err);
+#ifdef USE_WEIGHTS_FILE
+ opus_encoder_ctl(enc, OPUS_SET_DNN_BLOB(blob_data, blob_len));
+ opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
+ opus_dred_decoder_ctl(dred_dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
+#endif
while (!stop)
{
if (delayed_celt)
diff --git a/src/opus_encoder.c b/src/opus_encoder.c
index 3738c848..f6d3bc58 100644
--- a/src/opus_encoder.c
+++ b/src/opus_encoder.c
@@ -2847,6 +2847,19 @@ int opus_encoder_ctl(OpusEncoder *st, int request, ...)
}
}
break;
+#ifdef USE_WEIGHTS_FILE
+ case OPUS_SET_DNN_BLOB_REQUEST:
+ {
+ const unsigned char *data = va_arg(ap, const unsigned char *);
+ opus_int32 len = va_arg(ap, opus_int32);
+ if(len<0 || data == NULL)
+ {
+ goto bad_arg;
+ }
+ return dred_encoder_load_model(&st->dred_encoder, data, len);
+ }
+ break;
+#endif
case CELT_GET_MODE_REQUEST:
{
const CELTMode ** value = va_arg(ap, const CELTMode**);