diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-11-08 01:46:38 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-11-08 02:10:50 +0300 |
commit | 222662dac8bfbc2d764142d178b91f9d928f56cc (patch) | |
tree | 377c0fb53ac7238c43d433b28d1dbf2a5950f26b | |
parent | 4e104555e98c8227464f02ee388d983d387612b6 (diff) |
DRED: quantize scale and dead zone to 8 bits
-rwxr-xr-x | autogen.sh | 2 | ||||
-rw-r--r-- | dnn/torch/rdovae/export_rdovae_weights.py | 28 | ||||
-rw-r--r-- | silk/dred_decoder.c | 2 | ||||
-rw-r--r-- | silk/dred_encoder.c | 8 |
4 files changed, 23 insertions, 17 deletions
@@ -9,7 +9,7 @@ set -e srcdir=`dirname $0` test -n "$srcdir" && cd "$srcdir" -dnn/download_model.sh 2386a60 +dnn/download_model.sh b6095cf echo "Updating build configuration files, please wait...." diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py index 001999c6..55093d76 100644 --- a/dnn/torch/rdovae/export_rdovae_weights.py +++ b/dnn/torch/rdovae/export_rdovae_weights.py @@ -59,33 +59,35 @@ def dump_statistical_model(writer, w, name): p0 = torch.sigmoid(w[:, 4 , :]).numpy() p0 = 1 - r ** (0.5 + 0.5 * p0) + scales_norm = 255./256./(1e-15+np.max(quant_scales,axis=0)) + quant_scales = quant_scales*scales_norm quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16) - dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16) + dead_zone_q8 = np.clip(np.round(dead_zone * 2**8), 0, 255).astype(np.uint16) r_q8 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8) p0_q8 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16) mask = (np.max(r_q8,axis=0) > 0) * (np.min(p0_q8,axis=0) < 255) quant_scales_q8 = quant_scales_q8[:, mask] - dead_zone_q10 = dead_zone_q10[:, mask] + dead_zone_q8 = dead_zone_q8[:, mask] r_q8 = r_q8[:, mask] p0_q8 = p0_q8[:, mask] N = r_q8.shape[-1] - print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint16', static=False) - print_vector(writer.source, dead_zone_q10, f'dred_{name}_dead_zone_q10', dtype='opus_uint16', static=False) + print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint8', static=False) + print_vector(writer.source, dead_zone_q8, f'dred_{name}_dead_zone_q8', dtype='opus_uint8', static=False) print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False) print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False) writer.header.write( f""" -extern const opus_uint16 dred_{name}_quant_scales_q8[{levels * N}]; -extern const opus_uint16 dred_{name}_dead_zone_q10[{levels * N}]; +extern const opus_uint8 dred_{name}_quant_scales_q8[{levels * N}]; +extern const opus_uint8 dred_{name}_dead_zone_q8[{levels * N}]; extern const opus_uint8 dred_{name}_r_q8[{levels * N}]; extern const opus_uint8 dred_{name}_p0_q8[{levels * N}]; """ ) - return N, mask + return N, mask, torch.tensor(scales_norm[mask]) def c_export(args, model): @@ -128,14 +130,16 @@ f""" levels = qembedding.shape[0] qembedding = torch.reshape(qembedding, (levels, 6, -1)) - latent_dim, latent_mask = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent') - state_dim, state_mask = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state') + latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent') + state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state') padded_latent_dim = (latent_dim+7)//8*8 latent_pad = padded_latent_dim - latent_dim; w = latent_out.weight[latent_mask,:] + w = w/latent_scale[:, None] w = torch.cat([w, torch.zeros(latent_pad, w.shape[1])], dim=0) b = latent_out.bias[latent_mask] + b = b/latent_scale b = torch.cat([b, torch.zeros(latent_pad)], dim=0) latent_out.weight = torch.nn.Parameter(w) latent_out.bias = torch.nn.Parameter(b) @@ -143,16 +147,18 @@ f""" padded_state_dim = (state_dim+7)//8*8 state_pad = padded_state_dim - state_dim; w = state_out.weight[state_mask,:] + w = w/state_scale[:, None] w = torch.cat([w, torch.zeros(state_pad, w.shape[1])], dim=0) b = state_out.bias[state_mask] + b = b/state_scale b = torch.cat([b, torch.zeros(state_pad)], dim=0) state_out.weight = torch.nn.Parameter(w) state_out.bias = torch.nn.Parameter(b) latent_in = model.get_submodule('core_decoder.module.dense_1') state_in = model.get_submodule('core_decoder.module.hidden_init') - latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask]) - state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask]) + latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask]*latent_scale) + state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask]*state_scale) # encoder encoder_dense_layers = [ diff --git a/silk/dred_decoder.c b/silk/dred_decoder.c index c1489f3c..0d22f468 100644 --- a/silk/dred_decoder.c +++ b/silk/dred_decoder.c @@ -45,7 +45,7 @@ static int sign_extend(int x, int b) { return (x ^ m) - m; } -static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint8 *r, const opus_uint8 *p0, int dim) { +static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint8 *scale, const opus_uint8 *r, const opus_uint8 *p0, int dim) { int i; for (i=0;i<dim;i++) { int q; diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c index fb184103..b567a223 100644 --- a/silk/dred_encoder.c +++ b/silk/dred_encoder.c @@ -223,7 +223,7 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex } } -static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint8 *r, const opus_uint8 *p0, int dim) { +static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint8 *scale, const opus_uint8 *dzone, const opus_uint8 *r, const opus_uint8 *p0, int dim) { int i; int q[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)]; float xq[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)]; @@ -233,7 +233,7 @@ static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 * /* This is split into multiple loops (with temporary arrays) so that the compiler can vectorize all of it, and so we can call the vector tanh(). */ for (i=0;i<dim;i++) { - delta[i] = dzone[i]*(1.f/1024.f); + delta[i] = dzone[i]*(1.f/256.f); xq[i] = x[i]*scale[i]*(1.f/256.f); deadzone[i] = xq[i]/(delta[i]+eps); } @@ -272,7 +272,7 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk &ec_encoder, enc->initial_state, dred_state_quant_scales_q8 + state_qoffset, - dred_state_dead_zone_q10 + state_qoffset, + dred_state_dead_zone_q8 + state_qoffset, dred_state_r_q8 + state_qoffset, dred_state_p0_q8 + state_qoffset, DRED_STATE_DIM); @@ -291,7 +291,7 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk &ec_encoder, enc->latents_buffer + (i+enc->latent_offset) * DRED_LATENT_DIM, dred_latent_quant_scales_q8 + offset, - dred_latent_dead_zone_q10 + offset, + dred_latent_dead_zone_q8 + offset, dred_latent_r_q8 + offset, dred_latent_p0_q8 + offset, DRED_LATENT_DIM |