diff options
author | Jean-Marc Valin <jmvalin@jmvalin.ca> | 2024-01-17 10:26:48 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@jmvalin.ca> | 2024-01-17 10:26:48 +0300 |
commit | 4f311a1ad44f1b7bd60e32984ca0604c46b6c593 (patch) | |
tree | 7bc6041a00e98dd1ff926253e68cffb2c32ece6f | |
parent | 26ddfd713537accce773acc12f565021f4f6d28c (diff) |
PLC export script
mostly untested
-rw-r--r-- | dnn/torch/plc/export_plc.py | 100 | ||||
-rw-r--r-- | dnn/torch/plc/train_plc.py | 2 |
2 files changed, 101 insertions, 1 deletions
diff --git a/dnn/torch/plc/export_plc.py b/dnn/torch/plc/export_plc.py new file mode 100644 index 00000000..7f153c4c --- /dev/null +++ b/dnn/torch/plc/export_plc.py @@ -0,0 +1,100 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +import argparse +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange')) + + +parser = argparse.ArgumentParser() + +parser.add_argument('checkpoint', type=str, help='model checkpoint') +parser.add_argument('output_dir', type=str, help='output folder') + +args = parser.parse_args() + +import torch +import numpy as np + +import plc +from wexchange.torch import dump_torch_weights +from wexchange.c_export import CWriter, print_vector + +def c_export(args, model): + + message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}" + + writer = CWriter(os.path.join(args.output_dir, "plc_data"), message=message, model_struct_name='PLCModel') + writer.header.write( +f""" +#include "opus_types.h" +""" + ) + + dense_layers = [ + ('dense_in', "plc_dense_in"), + ('dense_out', "plc_dense_out") + ] + + + for name, export_name in dense_layers: + layer = model.get_submodule(name) + dump_torch_weights(writer, layer, name=export_name, verbose=True, quantize=False, scale=None) + + + gru_layers = [ + ("gru1", "plc_gru1"), + ("gru2", "plc_gru2"), + ] + + max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=True, scale=None, recurrent_scale=None) + for name, export_name in gru_layers]) + + writer.header.write( +f""" + +#define PLC_MAX_RNN_UNITS {max_rnn_units} + +""" + ) + + writer.close() + + +if __name__ == "__main__": + + os.makedirs(args.output_dir, exist_ok=True) + checkpoint = torch.load(args.checkpoint, map_location='cpu') + model = plc.PLC(*checkpoint['model_args'], **checkpoint['model_kwargs']) + model.load_state_dict(checkpoint['state_dict'], strict=False) + #checkpoint = torch.load(args.checkpoint, map_location='cpu') + #model.load_state_dict(checkpoint['state_dict']) + c_export(args, model) diff --git a/dnn/torch/plc/train_plc.py b/dnn/torch/plc/train_plc.py index 97be2c04..12b31c4e 100644 --- a/dnn/torch/plc/train_plc.py +++ b/dnn/torch/plc/train_plc.py @@ -138,7 +138,7 @@ if __name__ == '__main__': ) # save checkpoint - checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth') + checkpoint_path = os.path.join(checkpoint_dir, f'plc{args.suffix}_{epoch}.pth') checkpoint['state_dict'] = model.state_dict() checkpoint['loss'] = running_loss / len(dataloader) checkpoint['epoch'] = epoch |