Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-11-09 01:32:43 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-11-09 01:32:43 +0300
commit77594bf158bae48c8267f1f548209caa118ae7d5 (patch)
tree93b221f0e5966cbb1f0250cd496c8652ca8b2c2f
parent222662dac8bfbc2d764142d178b91f9d928f56cc (diff)
Dumping RDOVAE stats from XML
-rw-r--r--dnn/torch/rdovae/export_rdovae_weights.py33
1 files changed, 29 insertions, 4 deletions
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py
index 55093d76..a7585c9d 100644
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -48,8 +48,27 @@ from rdovae import RDOVAE
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector
-
-def dump_statistical_model(writer, w, name):
+def print_xml(xmlout, val, param, anchor, name):
+ xmlout.write(
+f"""
+ <table anchor="{anchor}_{name}">
+ <name>{param} values for {name}</name>
+ <thead>
+ <tr><th>k</th><th>Q0</th><th>Q1</th><th>Q2</th><th>Q3</th><th>Q4</th><th>Q5</th><th>Q6</th><th>Q7</th><th>Q8</th><th>Q9</th><th>Q10</th><th>Q11</th><th>Q12</th><th>Q13</th><th>Q14</th><th>Q15</th></tr>
+ </thead>
+ <tbody>
+""")
+ for k in range(val.shape[1]):
+ xmlout.write(f" <tr><th>{k}</th>")
+ for j in range(val.shape[0]):
+ xmlout.write(f"<th>{val[j][k]}</th>")
+ xmlout.write("</tr>\n")
+ xmlout.write(
+f"""
+ </tbody>
+ </table>
+""")
+def dump_statistical_model(writer, w, name, xmlout):
levels = w.shape[0]
print("printing statistical model")
@@ -78,6 +97,11 @@ def dump_statistical_model(writer, w, name):
print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)
+ print_xml(xmlout, quant_scales_q8, "Scale", "scale", name)
+ print_xml(xmlout, dead_zone_q8, "Dead zone", "deadzone", name)
+ print_xml(xmlout, r_q8, "Decay (r)", "decay", name)
+ print_xml(xmlout, p0_q8, "P(0)", "p0", name)
+
writer.header.write(
f"""
extern const opus_uint8 dred_{name}_quant_scales_q8[{levels * N}];
@@ -98,6 +122,7 @@ def c_export(args, model):
dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message, model_struct_name='RDOVAEDec')
stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats_data"), message=message, enable_binary_blob=False)
constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True, enable_binary_blob=False)
+ xmlout = open("stats.xml", "w")
# some custom includes
for writer in [enc_writer, dec_writer]:
@@ -130,8 +155,8 @@ f"""
levels = qembedding.shape[0]
qembedding = torch.reshape(qembedding, (levels, 6, -1))
- latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent')
- state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state')
+ latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent', xmlout)
+ state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state', xmlout)
padded_latent_dim = (latent_dim+7)//8*8
latent_pad = padded_latent_dim - latent_dim;