Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-12-22 00:57:35 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-12-22 00:57:35 +0300
commitc40add59af065f4fdf80048f2dad91d6b4480114 (patch)
treeccc4c9bc1e5802949fe9ed46551875355233d43e
parent627aa7f5b3688ba787c69e55e199ba82e2013be0 (diff)
lossgen: can now dump weights
-rw-r--r--dnn/torch/lossgen/export_lossgen.py101
-rw-r--r--dnn/torch/lossgen/lossgen.py5
-rw-r--r--dnn/torch/lossgen/test_lossgen.py3
-rw-r--r--dnn/torch/lossgen/train_lossgen.py10
4 files changed, 109 insertions, 10 deletions
diff --git a/dnn/torch/lossgen/export_lossgen.py b/dnn/torch/lossgen/export_lossgen.py
new file mode 100644
index 00000000..1f7df957
--- /dev/null
+++ b/dnn/torch/lossgen/export_lossgen.py
@@ -0,0 +1,101 @@
+"""
+/* Copyright (c) 2022 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
import os
import argparse
import sys

# Make the sibling weight-exchange package importable before it is imported below.
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))


parser = argparse.ArgumentParser()

parser.add_argument('checkpoint', type=str, help='model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')

# Parse before the heavy imports below so `--help` and argument errors
# return immediately without paying the torch import cost.
args = parser.parse_args()

import torch
import numpy as np

import lossgen
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector
+
def c_export(args, model):
    """Dump the LossGen model weights as C source via CWriter.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``checkpoint`` (path, used only in the generated
        header comment) and ``output_dir`` (folder that receives the
        generated ``lossgen_data`` C/H pair).
    model : lossgen.LossGen
        Trained model whose submodules are exported.
    """

    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"

    writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen')
    writer.header.write(
f"""
#include "opus_types.h"
"""
    )

    # Dense layers are small, so they are exported unquantized.
    dense_layers = [
        ('dense_in', "lossgen_dense_in"),
        ('dense_out', "lossgen_dense_out")
    ]


    for name, export_name in dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(writer, layer, name=export_name, verbose=True, quantize=False, scale=None)


    gru_layers = [
        ("gru1", "lossgen_gru1"),
        ("gru2", "lossgen_gru2"),
    ]

    # dump_torch_weights returns the GRU's unit count; the largest one
    # sizes the scratch buffer on the C side.  The export name is passed
    # as a keyword for consistency with the dense-layer loop above, and a
    # generator avoids materializing a throwaway list inside max().
    max_rnn_units = max(dump_torch_weights(writer, model.get_submodule(name), name=export_name, verbose=True,
                                           input_sparse=False, quantize=True, scale=None, recurrent_scale=None)
                        for name, export_name in gru_layers)

    writer.header.write(
f"""

#define LOSSGEN_MAX_RNN_UNITS {max_rnn_units}

"""
    )

    writer.close()
+
+
+if __name__ == "__main__":
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
+ model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
+ #model = LossGen()
+ #checkpoint = torch.load(args.checkpoint, map_location='cpu')
+ #model.load_state_dict(checkpoint['state_dict'])
+ c_export(args, model)
diff --git a/dnn/torch/lossgen/lossgen.py b/dnn/torch/lossgen/lossgen.py
index a1f2708b..9025165c 100644
--- a/dnn/torch/lossgen/lossgen.py
+++ b/dnn/torch/lossgen/lossgen.py
@@ -8,7 +8,8 @@ class LossGen(nn.Module):
self.gru1_size = gru1_size
self.gru2_size = gru2_size
- self.gru1 = nn.GRU(2, self.gru1_size, batch_first=True)
+ self.dense_in = nn.Linear(2, 8)
+ self.gru1 = nn.GRU(8, self.gru1_size, batch_first=True)
self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
self.dense_out = nn.Linear(self.gru2_size, 1)
@@ -22,7 +23,7 @@ class LossGen(nn.Module):
else:
gru1_state = states[0]
gru2_state = states[1]
- x = torch.cat([loss, perc], dim=-1)
+ x = torch.tanh(self.dense_in(torch.cat([loss, perc], dim=-1)))
gru1_out, gru1_state = self.gru1(x, gru1_state)
gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
return self.dense_out(gru2_out), [gru1_state, gru2_state]
diff --git a/dnn/torch/lossgen/test_lossgen.py b/dnn/torch/lossgen/test_lossgen.py
index 0258d0e6..95659b1f 100644
--- a/dnn/torch/lossgen/test_lossgen.py
+++ b/dnn/torch/lossgen/test_lossgen.py
@@ -18,10 +18,7 @@ args = parser.parse_args()
checkpoint = torch.load(args.model, map_location='cpu')
-
model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
-
-
model.load_state_dict(checkpoint['state_dict'], strict=False)
states=None
diff --git a/dnn/torch/lossgen/train_lossgen.py b/dnn/torch/lossgen/train_lossgen.py
index f0f6dd75..26e0f012 100644
--- a/dnn/torch/lossgen/train_lossgen.py
+++ b/dnn/torch/lossgen/train_lossgen.py
@@ -32,13 +32,13 @@ class LossDataset(torch.utils.data.Dataset):
return [self.loss[index, :, :], self.perc[index, :, :]+r0+r1]
-adam_betas = [0.8, 0.99]
+adam_betas = [0.8, 0.98]
adam_eps = 1e-8
-batch_size=512
-lr_decay = 0.0001
-lr = 0.001
+batch_size=256
+lr_decay = 0.001
+lr = 0.003
epsilon = 1e-5
-epochs = 20
+epochs = 2000
checkpoint_dir='checkpoint'
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint = dict()