diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-11-09 01:32:43 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-11-09 01:32:43 +0300 |
commit | 77594bf158bae48c8267f1f548209caa118ae7d5 (patch) | |
tree | 93b221f0e5966cbb1f0250cd496c8652ca8b2c2f | |
parent | 222662dac8bfbc2d764142d178b91f9d928f56cc (diff) |
Dumping RDOVAE stats from XML
-rw-r--r-- | dnn/torch/rdovae/export_rdovae_weights.py | 33 |
1 files changed, 29 insertions, 4 deletions
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py index 55093d76..a7585c9d 100644 --- a/dnn/torch/rdovae/export_rdovae_weights.py +++ b/dnn/torch/rdovae/export_rdovae_weights.py @@ -48,8 +48,27 @@ from rdovae import RDOVAE from wexchange.torch import dump_torch_weights from wexchange.c_export import CWriter, print_vector - -def dump_statistical_model(writer, w, name): +def print_xml(xmlout, val, param, anchor, name): + xmlout.write( +f""" + <table anchor="{anchor}_{name}"> + <name>{param} values for {name}</name> + <thead> + <tr><th>k</th><th>Q0</th><th>Q1</th><th>Q2</th><th>Q3</th><th>Q4</th><th>Q5</th><th>Q6</th><th>Q7</th><th>Q8</th><th>Q9</th><th>Q10</th><th>Q11</th><th>Q12</th><th>Q13</th><th>Q14</th><th>Q15</th></tr> + </thead> + <tbody> +""") + for k in range(val.shape[1]): + xmlout.write(f" <tr><th>{k}</th>") + for j in range(val.shape[0]): + xmlout.write(f"<th>{val[j][k]}</th>") + xmlout.write("</tr>\n") + xmlout.write( +f""" + </tbody> + </table> +""") +def dump_statistical_model(writer, w, name, xmlout): levels = w.shape[0] print("printing statistical model") @@ -78,6 +97,11 @@ def dump_statistical_model(writer, w, name): print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False) print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False) + print_xml(xmlout, quant_scales_q8, "Scale", "scale", name) + print_xml(xmlout, dead_zone_q8, "Dead zone", "deadzone", name) + print_xml(xmlout, r_q8, "Decay (r)", "decay", name) + print_xml(xmlout, p0_q8, "P(0)", "p0", name) + writer.header.write( f""" extern const opus_uint8 dred_{name}_quant_scales_q8[{levels * N}]; @@ -98,6 +122,7 @@ def c_export(args, model): dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message, model_struct_name='RDOVAEDec') stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats_data"), message=message, enable_binary_blob=False) constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True, enable_binary_blob=False) + xmlout = open("stats.xml", "w") # some custom includes for writer in [enc_writer, dec_writer]: @@ -130,8 +155,8 @@ f""" levels = qembedding.shape[0] qembedding = torch.reshape(qembedding, (levels, 6, -1)) - latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent') - state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state') + latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent', xmlout) + state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state', xmlout) padded_latent_dim = (latent_dim+7)//8*8 latent_pad = padded_latent_dim - latent_dim; |