author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-10-12 01:27:55 +0300
---|---|---
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-10-12 01:27:55 +0300
commit | 90c373be08720415a6ae9e01e6dcd46faf82a8e5 (patch) |
tree | c059a803412b0aeb1a623c895db80ad76bbdef91 |
parent | ea423519a0dded757474657580360d7870f2cb5a (diff) |
oops, add dumping script
-rw-r--r-- | dnn/torch/fargan/dump_fargan_weights.py | 110
1 file changed, 110 insertions, 0 deletions
diff --git a/dnn/torch/fargan/dump_fargan_weights.py b/dnn/torch/fargan/dump_fargan_weights.py
new file mode 100644
index 00000000..3dfcdb9c
--- /dev/null
+++ b/dnn/torch/fargan/dump_fargan_weights.py
@@ -0,0 +1,110 @@
+import os
+import sys
+import argparse
+
+import torch
+from torch import nn
+
+
+sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
+import wexchange.torch
+
+import fargan
+#from models import model_dict
+
+unquantized = [
+    'cond_net.pembed',
+    'cond_net.fdense1',
+    'cond_net.fconv1',
+    'cond_net.fconv2',
+    'cont_net.0',
+    'sig_net.cond_gain_dense',
+    'sig_net.fwc0.conv',
+    'sig_net.fwc0.glu.gate',
+    'sig_net.dense1_glu.gate',
+    'sig_net.gru1_glu.gate',
+    'sig_net.gru2_glu.gate',
+    'sig_net.gru3_glu.gate',
+    'sig_net.skip_glu.gate',
+    'sig_net.skip_dense',
+    'sig_net.sig_dense_out',
+    'sig_net.gain_dense_out'
+]
+
+description = f"""
+This is an unsafe dumping script for FARGAN models. It assumes that all weights are included in Linear, Conv1d or GRU layers
+and will fail to export any other weights.
+
+Furthermore, the quantize option relies on the following explicit list of layers to be excluded:
+{unquantized}.
+
+Modify this script manually if adjustments are needed.
+"""
+
+parser = argparse.ArgumentParser(description=description)
+parser.add_argument('weightfile', type=str, help='weight file path')
+parser.add_argument('export_folder', type=str)
+parser.add_argument('--export-filename', type=str, default='fargan_data', help='filename for source and header file (.c and .h will be added), defaults to fargan_data')
+parser.add_argument('--struct-name', type=str, default='FARGAN', help='name for C struct, defaults to FARGAN')
+parser.add_argument('--quantize', action='store_true', help='apply quantization')
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    print(f"loading weights from {args.weightfile}...")
+    saved_gen = torch.load(args.weightfile, map_location='cpu')
+    saved_gen['model_args'] = ()
+    saved_gen['model_kwargs'] = {'cond_size': 256, 'gamma': 0.9}
+
+    model = fargan.FARGAN(*saved_gen['model_args'], **saved_gen['model_kwargs'])
+    model.load_state_dict(saved_gen['state_dict'], strict=False)
+    def _remove_weight_norm(m):
+        try:
+            torch.nn.utils.remove_weight_norm(m)
+        except ValueError:  # this module didn't have weight norm
+            return
+    model.apply(_remove_weight_norm)
+
+
+    print("dumping model...")
+    quantize_model = args.quantize
+
+    output_folder = args.export_folder
+    os.makedirs(output_folder, exist_ok=True)
+
+    writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name)
+
+    for name, module in model.named_modules():
+
+        if quantize_model:
+            quantize = name not in unquantized
+            scale = None if quantize else 1/128
+        else:
+            quantize = False
+            scale = 1/128
+
+        if isinstance(module, nn.Linear):
+            print(f"dumping linear layer {name}...")
+            wexchange.torch.dump_torch_dense_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
+
+        elif isinstance(module, nn.Conv1d):
+            print(f"dumping conv1d layer {name}...")
+            wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
+
+        elif isinstance(module, nn.GRU):
+            print(f"dumping GRU layer {name}...")
+            wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)
+
+        elif isinstance(module, nn.GRUCell):
+            print(f"dumping GRUCell layer {name}...")
+            wexchange.torch.dump_torch_grucell_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)
+
+        elif isinstance(module, nn.Embedding):
+            print(f"dumping Embedding layer {name}...")
+            wexchange.torch.dump_torch_embedding_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
+            #wexchange.torch.dump_torch_embedding_weights(writer, module)
+
+        else:
+            print(f"Ignoring layer {name}...")
+
+    writer.close()
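The per-layer quantization decision is easy to miss in the diff view: with `--quantize`, any layer whose dotted name appears in `unquantized` is dumped unquantized at a fixed scale of 1/128, while every other supported layer is quantized with `scale=None`, letting the exporter pick a scale. Below is a minimal, self-contained sketch of just that dispatch; `ToyNet` and `EXCLUDE` are hypothetical stand-ins invented for illustration, not part of the patch, and the `wexchange` export calls are replaced by prints.

```python
from torch import nn

# Hypothetical stand-in for FARGAN; layer names and the exclusion
# list below are invented for illustration only.
class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(256, 64)
        self.dense1 = nn.Linear(64, 128)
        self.conv1 = nn.Conv1d(128, 128, 3)
        self.gru = nn.GRU(128, 128)

EXCLUDE = ['embed', 'dense1']   # plays the role of `unquantized` above

model = ToyNet()
quantize_model = True           # corresponds to passing --quantize

# named_modules() also yields the root module (name ''), which matches
# none of the dumped layer types and falls through to "Ignoring".
for name, module in model.named_modules():
    if not isinstance(module, (nn.Linear, nn.Conv1d, nn.GRU,
                               nn.GRUCell, nn.Embedding)):
        print(f"Ignoring layer {name}...")
        continue
    if quantize_model:
        quantize = name not in EXCLUDE
        scale = None if quantize else 1/128  # None: exporter picks the scale
    else:
        quantize = False
        scale = 1/128
    # The real script would call the matching
    # wexchange.torch.dump_torch_*_weights() here.
    print(f"{name}: quantize={quantize}, scale={scale}")
```

For reference, the script itself would be invoked along the lines of `python dump_fargan_weights.py <checkpoint>.pth <export_folder> --quantize` (the paths here are placeholders), which per the argparse defaults writes `fargan_data.c` and `fargan_data.h` into the export folder.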