Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-12-11 13:13:11 +0300
committerJan Buethe <jbuethe@amazon.de>2023-12-11 13:13:11 +0300
commit241781d4c0defd6214e2baef9a3376eb6d77d8a7 (patch)
tree2bc4c6bde578ea0370eaf57975c2fced5731006d
parent9de91f631c2147943ac70679d1f37239fb136a53 (diff)
parentd2f07aeb10636596d25990cfc1eb7b27f891e7bd (diff)
Merge remote-tracking branch 'origin/opus-ng-lace-integration2' into opus-ng-lace-integration2-model-tuningopus-ng-lace-integration2-model-tuning
-rw-r--r--dnn/osce.c2
-rw-r--r--dnn/osce.h6
-rw-r--r--dnn/osce_structs.h24
-rw-r--r--dnn/torch/osce/export_model_weights.py21
-rw-r--r--dnn/torch/weight-exchange/wexchange/torch/torch.py52
-rw-r--r--silk/API.h6
-rw-r--r--silk/control.h5
-rw-r--r--silk/dec_API.c25
-rw-r--r--silk/decode_frame.c4
-rw-r--r--silk/init_decoder.c29
-rw-r--r--silk/main.h4
-rw-r--r--silk/structs.h8
-rw-r--r--src/opus_decoder.c17
13 files changed, 167 insertions, 36 deletions
diff --git a/dnn/osce.c b/dnn/osce.c
index 503b5671..093f4f2b 100644
--- a/dnn/osce.c
+++ b/dnn/osce.c
@@ -833,7 +833,7 @@ void osce_enhance_frame(
/* enhancement only implemented for 20 ms frame at 16kHz */
if (psDec->fs_kHz != 16 || psDec->nb_subfr != 4)
{
- /* Question: reset state? */
+ osce_reset(&psDec->osce, psDec->osce.method);
return;
}
diff --git a/dnn/osce.h b/dnn/osce.h
index 98ce88c5..76463e1d 100644
--- a/dnn/osce.h
+++ b/dnn/osce.h
@@ -23,7 +23,13 @@
#define OSCE_METHOD_NOLACE 2
#endif
+#if !defined(DISABLE_NOLACE)
#define OSCE_DEFAULT_METHOD OSCE_METHOD_NOLACE
+#elif !defined(DISABLE_LACE)
+#define OSCE_DEFAULT_METHOD OSCE_METHOD_LACE
+#else
+#define OSCE_DEFAULT_METHOD OSCE_METHOD_NONE
+#endif
diff --git a/dnn/osce_structs.h b/dnn/osce_structs.h
index 4f1720dd..67fbf811 100644
--- a/dnn/osce_structs.h
+++ b/dnn/osce_structs.h
@@ -3,8 +3,12 @@
#include "opus_types.h"
#include "osce_config.h"
+#ifndef DISABLE_LACE
#include "lace_data.h"
+#endif
+#ifndef DISABLE_NOLACE
#include "nolace_data.h"
+#endif
#include "nndsp.h"
#include "nnet.h"
@@ -18,8 +22,9 @@ typedef struct {
float signal_history[OSCE_FEATURES_MAX_HISTORY];
} OSCEFeatureState;
-/* LACE */
+#ifndef DISABLE_LACE
+/* LACE */
typedef struct {
float feature_net_conv2_state[LACE_FEATURE_NET_CONV2_STATE_SIZE];
float feature_net_gru_state[LACE_COND_DIM]; /* ToDo: fix! */
@@ -33,12 +38,14 @@ typedef struct {
typedef struct
{
LACELayers layers;
- LACEState state;
float window[LACE_OVERLAP_SIZE];
} LACE;
-/* NoLACE */
+#endif /* #ifndef DISABLE_LACE */
+
+#ifndef DISABLE_NOLACE
+/* NoLACE */
typedef struct {
float feature_net_conv2_state[NOLACE_FEATURE_NET_CONV2_STATE_SIZE];
float feature_net_gru_state[NOLACE_COND_DIM];
@@ -62,19 +69,28 @@ typedef struct {
typedef struct {
NOLACELayers layers;
- NoLACEState state;
float window[LACE_OVERLAP_SIZE];
} NoLACE;
+#endif /* #ifndef DISABLE_NOLACE */
+
/* OSCEModel */
typedef struct {
+#ifndef DISABLE_LACE
LACE lace;
+#endif
+#ifndef DISABLE_NOLACE
NoLACE nolace;
+#endif
} OSCEModel;
typedef union {
+#ifndef DISABLE_LACE
LACEState lace;
+#endif
+#ifndef DISABLE_NOLACE
NoLACEState nolace;
+#endif
} OSCEState;
#endif \ No newline at end of file
diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py
index c3b723c7..786d3200 100644
--- a/dnn/torch/osce/export_model_weights.py
+++ b/dnn/torch/osce/export_model_weights.py
@@ -51,6 +51,7 @@ parser = argparse.ArgumentParser()
parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
+parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')
schedules = {
@@ -60,15 +61,15 @@ schedules = {
('feature_net.conv2', dict(quantize=True, scale=None)),
('feature_net.tconv', dict(quantize=True, scale=None)),
('feature_net.gru', dict()),
- ('cf1', dict()),
- ('cf2', dict()),
- ('af1', dict()),
+ ('cf1', dict(quantize=True, scale=None)),
+ ('cf2', dict(quantize=True, scale=None)),
+ ('af1', dict(quantize=True, scale=None)),
('tdshape1', dict()),
('tdshape2', dict()),
('tdshape3', dict()),
- ('af2', dict()),
- ('af3', dict()),
- ('af4', dict()),
+ ('af2', dict(quantize=True, scale=None)),
+ ('af3', dict(quantize=True, scale=None)),
+ ('af4', dict(quantize=True, scale=None)),
('post_cf1', dict(quantize=True, scale=None)),
('post_cf2', dict(quantize=True, scale=None)),
('post_af1', dict(quantize=True, scale=None)),
@@ -81,9 +82,9 @@ schedules = {
('feature_net.conv2', dict(quantize=True, scale=None)),
('feature_net.tconv', dict(quantize=True, scale=None)),
('feature_net.gru', dict()),
- ('cf1', dict()),
- ('cf2', dict()),
- ('af1', dict())
+ ('cf1', dict(quantize=True, scale=None)),
+ ('cf2', dict(quantize=True, scale=None)),
+ ('af1', dict(quantize=True, scale=None))
]
}
@@ -161,7 +162,7 @@ if __name__ == "__main__":
cwriter.header.write(f"#define {model_name.upper()}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}\n")
# dump layers
- if model_name in schedules:
+ if model_name in schedules and args.quantize:
osce_scheduled_dump(cwriter, model_name, model, schedules[model_name])
else:
osce_dump_generic(cwriter, model_name, model)
diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py
index 2dcee1c5..00a1a4bb 100644
--- a/dnn/torch/weight-exchange/wexchange/torch/torch.py
+++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py
@@ -45,7 +45,7 @@ except:
from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer, print_tconv1d_layer, print_conv2d_layer
-def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', kernel_scale=1/128, kernel_quantize=False, gain_scale=1/128, gain_quantize=False):
+def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', scale=1/128, quantize=False):
w_kernel = adaconv.conv_kernel.weight.detach().cpu().numpy().copy()
@@ -54,14 +54,34 @@ def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', kernel_sc
b_gain = adaconv.filter_gain.bias.detach().cpu().numpy().copy()
if isinstance(where, CWriter):
+ # pad kernel for quantization
+ left_padding = adaconv.padding[0]
+ kernel_size = adaconv.kernel_size
+ in_channels = adaconv.in_channels
+ out_channels = adaconv.out_channels
+ feature_dim = adaconv.feature_dim
+
+ if quantize and kernel_size % 8:
+ kernel_padding = 8 - (kernel_size % 8)
+ w_kernel = np.concatenate(
+ (np.zeros((out_channels, in_channels, kernel_padding, feature_dim)), w_kernel.reshape(out_channels, in_channels, kernel_size, feature_dim)),
+ dtype=w_kernel.dtype,
+ axis=2).reshape(-1, feature_dim)
+ b_kernel = np.concatenate(
+ (np.zeros((out_channels, in_channels, kernel_padding)), b_kernel.reshape(out_channels, in_channels, kernel_size)),
+ dtype=b_kernel.dtype,
+ axis=2).reshape(-1)
+ left_padding += kernel_padding
+ kernel_size += kernel_padding
+
# write relevant scalar parameters to header file
where.header.write(f"""
#define {name.upper()}_FILTER_GAIN_A {adaconv.filter_gain_a:f}f
#define {name.upper()}_FILTER_GAIN_B {adaconv.filter_gain_b:f}f
#define {name.upper()}_SHAPE_GAIN {adaconv.shape_gain:f}f
-#define {name.upper()}_KERNEL_SIZE {adaconv.kernel_size}
+#define {name.upper()}_KERNEL_SIZE {kernel_size}
#define {name.upper()}_FRAME_SIZE {adaconv.frame_size}
-#define {name.upper()}_LEFT_PADDING {adaconv.padding[0]}
+#define {name.upper()}_LEFT_PADDING {left_padding}
#define {name.upper()}_OVERLAP_SIZE {adaconv.overlap_size}
#define {name.upper()}_IN_CHANNELS {adaconv.in_channels}
#define {name.upper()}_OUT_CHANNELS {adaconv.out_channels}
@@ -70,8 +90,8 @@ def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', kernel_sc
"""
)
- print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=kernel_scale, format='torch', sparse=False, diagonal=False, quantize=kernel_quantize)
- print_dense_layer(where, name + "_gain", w_gain, b_gain, scale=gain_scale, format='torch', sparse=False, diagonal=False, quantize=gain_quantize)
+ print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=scale, format='torch', sparse=False, diagonal=False, quantize=quantize)
+ print_dense_layer(where, name + "_gain", w_gain, b_gain, format='torch', sparse=False, diagonal=False, quantize=False)
else:
@@ -81,7 +101,7 @@ def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', kernel_sc
np.save(where, 'bias_gain.npy', b_gain)
-def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', kernel_scale=1/128, kernel_quantize=False, gain_scale=1/128, gain_quantize=False, global_gain_scale=1/128, global_gain_quantize=False):
+def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', scale=1/128, quantize=False):
w_kernel = adaconv.conv_kernel.weight.detach().cpu().numpy().copy()
@@ -93,13 +113,23 @@ def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', kernel_sc
if isinstance(where, CWriter):
+ # pad kernel for quantization
+ left_padding = adaconv.padding[0]
+ kernel_size = adaconv.kernel_size
+
+ if quantize and w_kernel.shape[0] % 8:
+ kernel_padding = 8 - (w_kernel.shape[0] % 8)
+ w_kernel = np.concatenate((np.zeros((kernel_padding, w_kernel.shape[1])), w_kernel), dtype=w_kernel.dtype)
+ b_kernel = np.concatenate((np.zeros((kernel_padding)), b_kernel), dtype=b_kernel.dtype)
+ left_padding += kernel_padding
+ kernel_size += kernel_padding
# write relevant scalar parameters to header file
where.header.write(f"""
#define {name.upper()}_FILTER_GAIN_A {adaconv.filter_gain_a:f}f
#define {name.upper()}_FILTER_GAIN_B {adaconv.filter_gain_b:f}f
#define {name.upper()}_LOG_GAIN_LIMIT {adaconv.log_gain_limit:f}f
-#define {name.upper()}_KERNEL_SIZE {adaconv.kernel_size}
-#define {name.upper()}_LEFT_PADDING {adaconv.padding[0]}
+#define {name.upper()}_KERNEL_SIZE {kernel_size}
+#define {name.upper()}_LEFT_PADDING {left_padding}
#define {name.upper()}_FRAME_SIZE {adaconv.frame_size}
#define {name.upper()}_OVERLAP_SIZE {adaconv.overlap_size}
#define {name.upper()}_IN_CHANNELS {adaconv.in_channels}
@@ -110,9 +140,9 @@ def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', kernel_sc
"""
)
- print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=kernel_scale, format='torch', sparse=False, diagonal=False, quantize=kernel_quantize)
- print_dense_layer(where, name + "_gain", w_gain, b_gain, scale=gain_scale, format='torch', sparse=False, diagonal=False, quantize=gain_quantize)
- print_dense_layer(where, name + "_global_gain", w_global_gain, b_global_gain, scale=global_gain_scale, format='torch', sparse=False, diagonal=False, quantize=global_gain_quantize)
+ print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=scale, format='torch', sparse=False, diagonal=False, quantize=quantize)
+ print_dense_layer(where, name + "_gain", w_gain, b_gain, format='torch', sparse=False, diagonal=False, quantize=False)
+ print_dense_layer(where, name + "_global_gain", w_global_gain, b_global_gain, format='torch', sparse=False, diagonal=False, quantize=False)
else:
diff --git a/silk/API.h b/silk/API.h
index 6e623b84..29b6165f 100644
--- a/silk/API.h
+++ b/silk/API.h
@@ -100,8 +100,12 @@ opus_int silk_Get_Decoder_Size( /* O Returns error co
);
/*************************/
-/* Init or Reset decoder */
+/* Init and Reset decoder */
/*************************/
+opus_int silk_ResetDecoder( /* O Returns error code */
+ void *decState /* I/O State */
+);
+
opus_int silk_InitDecoder( /* O Returns error code */
void *decState /* I/O State */
);
diff --git a/silk/control.h b/silk/control.h
index d30d114c..f5633e62 100644
--- a/silk/control.h
+++ b/silk/control.h
@@ -147,6 +147,11 @@ typedef struct {
/* I: Enable Deep PLC */
opus_int enable_deep_plc;
+
+#ifdef ENABLE_OSCE
+ /* I: OSCE method */
+ opus_int osce_method;
+#endif
} silk_DecControlStruct;
#ifdef __cplusplus
diff --git a/silk/dec_API.c b/silk/dec_API.c
index a29ecc73..81c433d8 100644
--- a/silk/dec_API.c
+++ b/silk/dec_API.c
@@ -33,6 +33,10 @@ POSSIBILITY OF SUCH DAMAGE.
#include "stack_alloc.h"
#include "os_support.h"
+#ifdef ENABLE_OSCE
+#include "osce.h"
+#endif
+
/************************/
/* Decoder Super Struct */
/************************/
@@ -60,6 +64,24 @@ opus_int silk_Get_Decoder_Size( /* O Returns error co
}
/* Reset decoder state */
+opus_int silk_ResetDecoder( /* O Returns error code */
+ void *decState /* I/O State */
+)
+{
+ opus_int n, ret = SILK_NO_ERROR;
+ silk_decoder_state *channel_state = ((silk_decoder *)decState)->channel_state;
+
+ for( n = 0; n < DECODER_NUM_CHANNELS; n++ ) {
+ ret = silk_reset_decoder( &channel_state[ n ] );
+ }
+ silk_memset(&((silk_decoder *)decState)->sStereo, 0, sizeof(((silk_decoder *)decState)->sStereo));
+ /* Not strictly needed, but it's cleaner that way */
+ ((silk_decoder *)decState)->prev_decode_only_middle = 0;
+
+ return ret;
+}
+
+
opus_int silk_InitDecoder( /* O Returns error code */
void *decState /* I/O State */
)
@@ -301,6 +323,9 @@ opus_int silk_Decode( /* O Returns error co
} else {
condCoding = CODE_CONDITIONALLY;
}
+#ifdef ENABLE_OSCE
+ if (channel_state[n].osce.method != decControl->osce_method) {osce_reset(&channel_state[n].osce, decControl->osce_method);}
+#endif
ret += silk_decode_frame( &channel_state[ n ], psRangeDec, &samplesOut1_tmp[ n ][ 2 ], &nSamplesOutDec, lostFlag, condCoding,
#ifdef ENABLE_DEEP_PLC
n == 0 ? lpcnet : NULL,
diff --git a/silk/decode_frame.c b/silk/decode_frame.c
index 80283a38..d35aac08 100644
--- a/silk/decode_frame.c
+++ b/silk/decode_frame.c
@@ -140,6 +140,10 @@ opus_int silk_decode_frame(
/********************************************************/
osce_enhance_frame( psDec, psDecCtrl, pOut, ec_tell(psRangeDec) - ec_start, arch );
}
+ else
+ {
+ osce_reset( &psDec->osce, psDec->osce.method );
+ }
#endif
/************************************************/
diff --git a/silk/init_decoder.c b/silk/init_decoder.c
index 8e873c60..2dbda429 100644
--- a/silk/init_decoder.c
+++ b/silk/init_decoder.c
@@ -35,15 +35,17 @@ POSSIBILITY OF SUCH DAMAGE.
#include "osce.h"
#endif
+#include "struct.h"
+
/************************/
-/* Init Decoder State */
+/* Reset Decoder State */
/************************/
-opus_int silk_init_decoder(
+opus_int silk_reset_decoder(
silk_decoder_state *psDec /* I/O Decoder state pointer */
)
{
/* Clear the entire encoder state, except anything copied */
- silk_memset( psDec, 0, sizeof( silk_decoder_state ) );
+ silk_memset( &psDec->SILK_DECODER_STATE_RESET_START, 0, sizeof( silk_decoder_state ) - ((char*) &psDec->SILK_DECODER_STATE_RESET_START - (char*)psDec) );
/* Used to deactivate LSF interpolation */
psDec->first_frame_after_reset = 1;
@@ -57,9 +59,30 @@ opus_int silk_init_decoder(
silk_PLC_Reset( psDec );
#ifdef ENABLE_OSCE
+ /* Reset OSCE state and method */
+ osce_reset(&psDec->osce, OSCE_DEFAULT_METHOD);
+#endif
+
+ return 0;
+}
+
+
+/************************/
+/* Init Decoder State */
+/************************/
+opus_int silk_init_decoder(
+ silk_decoder_state *psDec /* I/O Decoder state pointer */
+)
+{
+ /* Clear the entire encoder state, except anything copied */
+ silk_memset( psDec, 0, sizeof( silk_decoder_state ) );
+
+#ifdef ENABLE_OSCE
osce_init(&psDec->osce, OSCE_DEFAULT_METHOD, NULL);
#endif
+ silk_reset_decoder( psDec );
+
return(0);
}
diff --git a/silk/main.h b/silk/main.h
index c67775ef..d5cb2a6e 100644
--- a/silk/main.h
+++ b/silk/main.h
@@ -389,6 +389,10 @@ void silk_NLSF_decode(
/****************************************************/
/* Decoder Functions */
/****************************************************/
+opus_int silk_reset_decoder(
+ silk_decoder_state *psDec /* I/O Decoder state pointer */
+);
+
opus_int silk_init_decoder(
silk_decoder_state *psDec /* I/O Decoder state pointer */
);
diff --git a/silk/structs.h b/silk/structs.h
index 5f912edf..123b384f 100644
--- a/silk/structs.h
+++ b/silk/structs.h
@@ -284,6 +284,10 @@ typedef struct {
/* Decoder state */
/********************************/
typedef struct {
+#ifdef ENABLE_OSCE
+ silk_OSCE_struct osce;
+#endif
+#define SILK_DECODER_STATE_RESET_START prev_gain_Q16
opus_int32 prev_gain_Q16;
opus_int32 exc_Q14[ MAX_FRAME_LENGTH ];
opus_int32 sLPC_Q14_buf[ MAX_LPC_ORDER ];
@@ -324,10 +328,6 @@ typedef struct {
/* CNG state */
silk_CNG_struct sCNG;
-#ifdef ENABLE_OSCE
- silk_OSCE_struct osce;
-#endif
-
/* Stuff used for PLC */
opus_int lossCnt;
opus_int prevSignalType;
diff --git a/src/opus_decoder.c b/src/opus_decoder.c
index 1e0a1da4..44a183d4 100644
--- a/src/opus_decoder.c
+++ b/src/opus_decoder.c
@@ -57,6 +57,10 @@
#include "dred_rdovae_dec.h"
#endif
+#ifdef ENABLE_OSCE
+#include "osce.h"
+#endif
+
struct OpusDecoder {
int celt_dec_offset;
int silk_dec_offset;
@@ -383,7 +387,7 @@ static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
pcm_ptr = pcm_silk;
if (st->prev_mode==MODE_CELT_ONLY)
- silk_InitDecoder( silk_dec );
+ silk_ResetDecoder( silk_dec );
/* The SILK PLC cannot produce frames of less than 10 ms */
st->DecControl.payloadSize_ms = IMAX(10, 1000 * audiosize / st->Fs);
@@ -408,6 +412,15 @@ static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
}
}
st->DecControl.enable_deep_plc = st->complexity >= 5;
+#ifdef ENABLE_OSCE
+ st->DecControl.osce_method = OSCE_METHOD_NONE;
+#ifndef DISABLE_LACE
+ if (st->complexity >= 2) {st->DecControl.osce_method = OSCE_METHOD_LACE;}
+#endif
+#ifndef DISABLE_NOLACE
+ if (st->complexity >= 5) {st->DecControl.osce_method = OSCE_METHOD_NOLACE;}
+#endif
+#endif
lost_flag = data == NULL ? 1 : 2 * !!decode_fec;
decoded_samples = 0;
@@ -953,7 +966,7 @@ int opus_decoder_ctl(OpusDecoder *st, int request, ...)
((char*)&st->OPUS_DECODER_RESET_START - (char*)st));
celt_decoder_ctl(celt_dec, OPUS_RESET_STATE);
- silk_InitDecoder( silk_dec );
+ silk_ResetDecoder( silk_dec );
st->stream_channels = st->channels;
st->frame_size = st->Fs/400;
#ifdef ENABLE_DEEP_PLC