author     Jean-Marc Valin <jmvalin@amazon.com>    2023-12-22 00:57:35 +0300
committer  Jean-Marc Valin <jmvalin@amazon.com>    2023-12-22 00:57:35 +0300
commit     c40add59af065f4fdf80048f2dad91d6b4480114 (patch)
tree       ccc4c9bc1e5802949fe9ed46551875355233d43e
parent     627aa7f5b3688ba787c69e55e199ba82e2013be0 (diff)
lossgen: can now dump weights
-rw-r--r--  dnn/torch/lossgen/export_lossgen.py  101
-rw-r--r--  dnn/torch/lossgen/lossgen.py           5
-rw-r--r--  dnn/torch/lossgen/test_lossgen.py      3
-rw-r--r--  dnn/torch/lossgen/train_lossgen.py    10
4 files changed, 109 insertions, 10 deletions
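
The new export_lossgen.py (shown in full in the diff below) takes a checkpoint path and an output folder as positional arguments and dumps the lossgen weights through the wexchange helpers. A minimal sketch of what its __main__ path does, assuming a trained checkpoint exists; the file names here are hypothetical:

    # Sketch only; checkpoint and output names are made up for illustration.
    import os
    import torch
    import lossgen

    checkpoint_path = "checkpoint/lossgen_demo.pth"   # hypothetical checkpoint
    output_dir = "exported"                           # hypothetical output folder

    os.makedirs(output_dir, exist_ok=True)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    # export_lossgen.py then hands the model to c_export(), which dumps the two
    # dense layers unquantized, the two GRUs quantized, and emits a
    # LOSSGEN_MAX_RNN_UNITS define into the generated lossgen_data header.
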
diff --git a/dnn/torch/lossgen/export_lossgen.py b/dnn/torch/lossgen/export_lossgen.py
new file mode 100644
index 00000000..1f7df957
--- /dev/null
+++ b/dnn/torch/lossgen/export_lossgen.py
@@ -0,0 +1,101 @@
+"""
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+import sys
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('checkpoint', type=str, help='model checkpoint')
+parser.add_argument('output_dir', type=str, help='output folder')
+
+args = parser.parse_args()
+
+import torch
+import numpy as np
+
+import lossgen
+from wexchange.torch import dump_torch_weights
+from wexchange.c_export import CWriter, print_vector
+
+def c_export(args, model):
+
+    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
+
+    writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen')
+    writer.header.write(
+f"""
+#include "opus_types.h"
+"""
+    )
+
+    dense_layers = [
+        ('dense_in', "lossgen_dense_in"),
+        ('dense_out', "lossgen_dense_out")
+    ]
+
+
+    for name, export_name in dense_layers:
+        layer = model.get_submodule(name)
+        dump_torch_weights(writer, layer, name=export_name, verbose=True, quantize=False, scale=None)
+
+
+    gru_layers = [
+        ("gru1", "lossgen_gru1"),
+        ("gru2", "lossgen_gru2"),
+    ]
+
+    max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=True, scale=None, recurrent_scale=None)
+                         for name, export_name in gru_layers])
+
+    writer.header.write(
+f"""
+
+#define LOSSGEN_MAX_RNN_UNITS {max_rnn_units}
+
+"""
+    )
+
+    writer.close()
+
+
+if __name__ == "__main__":
+
+    os.makedirs(args.output_dir, exist_ok=True)
+    checkpoint = torch.load(args.checkpoint, map_location='cpu')
+    model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+    model.load_state_dict(checkpoint['state_dict'], strict=False)
+    #model = LossGen()
+    #checkpoint = torch.load(args.checkpoint, map_location='cpu')
+    #model.load_state_dict(checkpoint['state_dict'])
+    c_export(args, model)
diff --git a/dnn/torch/lossgen/lossgen.py b/dnn/torch/lossgen/lossgen.py
index a1f2708b..9025165c 100644
--- a/dnn/torch/lossgen/lossgen.py
+++ b/dnn/torch/lossgen/lossgen.py
@@ -8,7 +8,8 @@ class LossGen(nn.Module):
 
         self.gru1_size = gru1_size
         self.gru2_size = gru2_size
-        self.gru1 = nn.GRU(2, self.gru1_size, batch_first=True)
+        self.dense_in = nn.Linear(2, 8)
+        self.gru1 = nn.GRU(8, self.gru1_size, batch_first=True)
         self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
         self.dense_out = nn.Linear(self.gru2_size, 1)
 
@@ -22,7 +23,7 @@ class LossGen(nn.Module):
         else:
             gru1_state = states[0]
             gru2_state = states[1]
-        x = torch.cat([loss, perc], dim=-1)
+        x = torch.tanh(self.dense_in(torch.cat([loss, perc], dim=-1)))
         gru1_out, gru1_state = self.gru1(x, gru1_state)
         gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
         return self.dense_out(gru2_out), [gru1_state, gru2_state]
diff --git a/dnn/torch/lossgen/test_lossgen.py b/dnn/torch/lossgen/test_lossgen.py
index 0258d0e6..95659b1f 100644
--- a/dnn/torch/lossgen/test_lossgen.py
+++ b/dnn/torch/lossgen/test_lossgen.py
@@ -18,10 +18,7 @@ args = parser.parse_args()
 
 checkpoint = torch.load(args.model, map_location='cpu')
 
-model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
-
-model.load_state_dict(checkpoint['state_dict'], strict=False)
 
 
 states=None
 
diff --git a/dnn/torch/lossgen/train_lossgen.py b/dnn/torch/lossgen/train_lossgen.py
index f0f6dd75..26e0f012 100644
--- a/dnn/torch/lossgen/train_lossgen.py
+++ b/dnn/torch/lossgen/train_lossgen.py
@@ -32,13 +32,13 @@ class LossDataset(torch.utils.data.Dataset):
         return [self.loss[index, :, :], self.perc[index, :, :]+r0+r1]
 
 
-adam_betas = [0.8, 0.99]
+adam_betas = [0.8, 0.98]
 adam_eps = 1e-8
-batch_size=512
-lr_decay = 0.0001
-lr = 0.001
+batch_size=256
+lr_decay = 0.001
+lr = 0.003
 epsilon = 1e-5
-epochs = 20
+epochs = 2000
 checkpoint_dir='checkpoint'
 os.makedirs(checkpoint_dir, exist_ok=True)
 checkpoint = dict()
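
For orientation, a minimal sketch of one step through the updated forward path from the lossgen.py hunk above; the GRU sizes below are assumptions, since they come from constructor arguments that this diff does not show:

    import torch
    import torch.nn as nn

    gru1_size, gru2_size = 16, 16       # assumed sizes, not part of this diff
    dense_in  = nn.Linear(2, 8)         # new input layer added by this commit
    gru1      = nn.GRU(8, gru1_size, batch_first=True)
    gru2      = nn.GRU(gru1_size, gru2_size, batch_first=True)
    dense_out = nn.Linear(gru2_size, 1)

    loss = torch.zeros(1, 1, 1)         # the 'loss' input, shape (batch, seq, 1)
    perc = torch.full((1, 1, 1), 0.1)   # the 'perc' input, shape (batch, seq, 1)
    x = torch.tanh(dense_in(torch.cat([loss, perc], dim=-1)))   # 2 -> 8 features
    out1, _ = gru1(x)
    out2, _ = gru2(out1)
    pred = dense_out(out2)              # one output per step, as before

The train_lossgen.py hunk in the same commit also changes the training setup: Adam betas (0.8, 0.98), batch size 256, learning rate 0.003 with 0.001 decay, and 2000 epochs instead of 20.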