diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-12-18 14:19:55 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2024-01-20 16:44:22 +0300 |
commit | 299e38cab774fa4bd9708581210af8b09c6b5e4e (patch) | |
tree | 9c532b4579da306b0ed2915c0a90c29e8ba8b47a /dnn/torch/osce/export_model_weights.py | |
parent | 4f311a1ad44f1b7bd60e32984ca0604c46b6c593 (diff) |
Updated LACE and NoLACE models to version 2opus-ng-osce-models-v2
Diffstat (limited to 'dnn/torch/osce/export_model_weights.py')
-rw-r--r-- | dnn/torch/osce/export_model_weights.py | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py index f94431d3..0bec9604 100644 --- a/dnn/torch/osce/export_model_weights.py +++ b/dnn/torch/osce/export_model_weights.py @@ -43,6 +43,7 @@ from models import model_dict from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d from utils.layers.td_shaper import TDShaper +from utils.misc import remove_all_weight_norm from wexchange.torch import dump_torch_weights @@ -58,30 +59,30 @@ schedules = { 'nolace': [ ('pitch_embedding', dict()), ('feature_net.conv1', dict()), - ('feature_net.conv2', dict(quantize=True, scale=None)), - ('feature_net.tconv', dict(quantize=True, scale=None)), - ('feature_net.gru', dict()), + ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)), + ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)), + ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)), ('cf1', dict(quantize=True, scale=None)), ('cf2', dict(quantize=True, scale=None)), ('af1', dict(quantize=True, scale=None)), - ('tdshape1', dict()), - ('tdshape2', dict()), - ('tdshape3', dict()), + ('tdshape1', dict(quantize=True, scale=None)), + ('tdshape2', dict(quantize=True, scale=None)), + ('tdshape3', dict(quantize=True, scale=None)), ('af2', dict(quantize=True, scale=None)), ('af3', dict(quantize=True, scale=None)), ('af4', dict(quantize=True, scale=None)), - ('post_cf1', dict(quantize=True, scale=None)), - ('post_cf2', dict(quantize=True, scale=None)), - ('post_af1', dict(quantize=True, scale=None)), - ('post_af2', dict(quantize=True, scale=None)), - ('post_af3', dict(quantize=True, scale=None)) + ('post_cf1', dict(quantize=True, scale=None, sparse=True)), + ('post_cf2', dict(quantize=True, scale=None, sparse=True)), + ('post_af1', dict(quantize=True, scale=None, sparse=True)), + ('post_af2', dict(quantize=True, scale=None, sparse=True)), + ('post_af3', dict(quantize=True, scale=None, sparse=True)) ], 'lace' : [ ('pitch_embedding', dict()), ('feature_net.conv1', dict()), - ('feature_net.conv2', dict(quantize=True, scale=None)), - ('feature_net.tconv', dict(quantize=True, scale=None)), - ('feature_net.gru', dict()), + ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)), + ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)), + ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)), ('cf1', dict(quantize=True, scale=None)), ('cf2', dict(quantize=True, scale=None)), ('af1', dict(quantize=True, scale=None)) @@ -140,6 +141,7 @@ if __name__ == "__main__": checkpoint = torch.load(checkpoint_path, map_location='cpu') model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs']) model.load_state_dict(checkpoint['state_dict']) + remove_all_weight_norm(model, verbose=True) # CWriter model_name = checkpoint['setup']['model']['name'] |