diff options
Diffstat (limited to 'dnn/torch/osce/export_model_weights.py')
-rw-r--r-- | dnn/torch/osce/export_model_weights.py | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py index f94431d3..0bec9604 100644 --- a/dnn/torch/osce/export_model_weights.py +++ b/dnn/torch/osce/export_model_weights.py @@ -43,6 +43,7 @@ from models import model_dict from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d from utils.layers.td_shaper import TDShaper +from utils.misc import remove_all_weight_norm from wexchange.torch import dump_torch_weights @@ -58,30 +59,30 @@ schedules = { 'nolace': [ ('pitch_embedding', dict()), ('feature_net.conv1', dict()), - ('feature_net.conv2', dict(quantize=True, scale=None)), - ('feature_net.tconv', dict(quantize=True, scale=None)), - ('feature_net.gru', dict()), + ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)), + ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)), + ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)), ('cf1', dict(quantize=True, scale=None)), ('cf2', dict(quantize=True, scale=None)), ('af1', dict(quantize=True, scale=None)), - ('tdshape1', dict()), - ('tdshape2', dict()), - ('tdshape3', dict()), + ('tdshape1', dict(quantize=True, scale=None)), + ('tdshape2', dict(quantize=True, scale=None)), + ('tdshape3', dict(quantize=True, scale=None)), ('af2', dict(quantize=True, scale=None)), ('af3', dict(quantize=True, scale=None)), ('af4', dict(quantize=True, scale=None)), - ('post_cf1', dict(quantize=True, scale=None)), - ('post_cf2', dict(quantize=True, scale=None)), - ('post_af1', dict(quantize=True, scale=None)), - ('post_af2', dict(quantize=True, scale=None)), - ('post_af3', dict(quantize=True, scale=None)) + ('post_cf1', dict(quantize=True, scale=None, sparse=True)), + ('post_cf2', dict(quantize=True, scale=None, sparse=True)), + ('post_af1', dict(quantize=True, scale=None, sparse=True)), + ('post_af2', dict(quantize=True, scale=None, sparse=True)), + ('post_af3', dict(quantize=True, scale=None, sparse=True)) ], 'lace' : [ ('pitch_embedding', dict()), ('feature_net.conv1', dict()), - ('feature_net.conv2', dict(quantize=True, scale=None)), - ('feature_net.tconv', dict(quantize=True, scale=None)), - ('feature_net.gru', dict()), + ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)), + ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)), + ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)), ('cf1', dict(quantize=True, scale=None)), ('cf2', dict(quantize=True, scale=None)), ('af1', dict(quantize=True, scale=None)) @@ -140,6 +141,7 @@ if __name__ == "__main__": checkpoint = torch.load(checkpoint_path, map_location='cpu') model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs']) model.load_state_dict(checkpoint['state_dict']) + remove_all_weight_norm(model, verbose=True) # CWriter model_name = checkpoint['setup']['model']['name'] |